Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Optimizer MXKVStoreUpdater bug fix in serializeState method #14337

Merged

Conversation

satyakrishnagorti
Copy link
Contributor

@satyakrishnagorti satyakrishnagorti commented Mar 5, 2019

PR Fixes #14265

Description

Currently there is a bug in the way Optimizer is trying to serialize state which fails when trying to deserialize Optimizer that has no states (like SGD without momentum).

Issue

Currently the way serialize is being done is as below: (pasting Optimizer.serailizeState())

  override def serializeState(): Array[Byte] = {
        val bos = new ByteArrayOutputStream()
        try {
          val out = new ObjectOutputStream(bos)
          out.writeInt(states.size)
          states.foreach { case (k, v) =>
            if (v != null) {
              out.writeInt(k)
              val stateBytes = optimizer.serializeState(v)
              if (stateBytes == null) {
                out.writeInt(0)
              } else {
                out.writeInt(stateBytes.length)
                out.write(stateBytes)
              }
            }
          }
          out.flush()
          bos.toByteArray
        } finally {
         ...
      }
  }

When an Optimizer without states like SGD with momentum set as 0 is being used. The states map (Map[Int, AnyRef]) contains a (key, value) pair as (some integer index, null).

The above serialize method does not write k as the value of key and 0 as the value of stateBytes, due to the null check if (v != null)

Now while deserializing: (Pasting code from Optimizer.deserializeState())

  override def deserializeState(bytes: Array[Byte]): Unit = {
        val bis = new ByteArrayInputStream(bytes)
        var in: ObjectInputStream = null
        try {
          in = new ObjectInputStream(bis)
          val size = in.readInt()
          (0 until size).foreach(_ => {
            val key = in.readInt()
            val bytesLength = in.readInt()
            val value =
              if (bytesLength > 0) {
                val bytes = Array.fill[Byte](bytesLength)(0)
                in.readFully(bytes)
                optimizer.deserializeState(bytes)
              } else {
                null
              }
            states.update(key, value)
          })
        } finally {
          ...
      }
  }

In the foreach loop, the key is being read (which wasn't serialized previously) hence, this would cause an java.io.EOFException.

Solution.

Get rid of if (v != null) check and retain the rest.

@vandanavk
Copy link
Contributor

@mxnet-label-bot add [Scala, pr-awaiting-review]

@satyakrishnagorti Please add "Fixes #14265" to the PR description to close the issue automatically when the PR is merged.

@marcoabreu marcoabreu added pr-awaiting-review PR is waiting for code review Scala labels Mar 5, 2019
Copy link
Contributor

@zachgk zachgk left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks good to me. Thanks for your contribution.

Calling other Scala people to review: @piyushghai @andrewfayres @lanking520

@piyushghai
Copy link
Contributor

Do you reckon we could do with addition/updation of the Unit Test as well for this method ?

Copy link
Member

@lanking520 lanking520 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks for your contribution

@lanking520 lanking520 merged commit 12c41e6 into apache:master Mar 7, 2019
vdantu pushed a commit to vdantu/incubator-mxnet that referenced this pull request Mar 31, 2019
haohuanw pushed a commit to haohuanw/incubator-mxnet that referenced this pull request Jun 23, 2019
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
pr-awaiting-review PR is waiting for code review Scala
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants