diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Layout.scala b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Layout.scala
new file mode 100644
index 000000000000..cfe290c1aff7
--- /dev/null
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Layout.scala
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mxnet.javaapi
+
+/**
+ * Layout definition of DataDesc
+ * N Batch size
+ * C channels
+ * H Height
+ * W Weight
+ * T sequence length
+ * __undefined__ default value of Layout
+ */
+object Layout {
+ val UNDEFINED: String = org.apache.mxnet.Layout.UNDEFINED
+ val NCHW: String = org.apache.mxnet.Layout.NCHW
+ val NTC: String = org.apache.mxnet.Layout.NTC
+ val NT: String = org.apache.mxnet.Layout.NT
+ val N: String = org.apache.mxnet.Layout.N
+}
diff --git a/scala-package/examples/pom.xml b/scala-package/examples/pom.xml
index 07e430175c6f..d60782ffd06b 100644
--- a/scala-package/examples/pom.xml
+++ b/scala-package/examples/pom.xml
@@ -145,5 +145,10 @@
slf4j-simple
1.7.5
+
+ com.google.code.gson
+ gson
+ 2.8.5
+
diff --git a/scala-package/examples/scripts/infer/bert/get_bert_data.sh b/scala-package/examples/scripts/infer/bert/get_bert_data.sh
new file mode 100755
index 000000000000..609aae27cc66
--- /dev/null
+++ b/scala-package/examples/scripts/infer/bert/get_bert_data.sh
@@ -0,0 +1,31 @@
+#!/bin/bash
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+set -e
+
+MXNET_ROOT=$(cd "$(dirname $0)/../../.."; pwd)
+
+data_path=$MXNET_ROOT/scripts/infer/models/static-bert-qa/
+
+if [ ! -d "$data_path" ]; then
+ mkdir -p "$data_path"
+ curl https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/BertQA/vocab.json -o $data_path/vocab.json
+ curl https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/BertQA/static_bert_qa-0002.params -o $data_path/static_bert_qa-0002.params
+ curl https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/BertQA/static_bert_qa-symbol.json -o $data_path/static_bert_qa-symbol.json
+fi
diff --git a/scala-package/examples/scripts/infer/bert/run_bert_qa_example.sh b/scala-package/examples/scripts/infer/bert/run_bert_qa_example.sh
new file mode 100755
index 000000000000..d8ba092c5c1b
--- /dev/null
+++ b/scala-package/examples/scripts/infer/bert/run_bert_qa_example.sh
@@ -0,0 +1,27 @@
+#!/bin/bash
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+set -e
+
+MXNET_ROOT=$(cd "$(dirname $0)/../../../../.."; pwd)
+
+CLASS_PATH=$MXNET_ROOT/scala-package/assembly/target/*:$MXNET_ROOT/scala-package/examples/target/*:$MXNET_ROOT/scala-package/examples/target/classes/lib/*
+
+java -Xmx8G -Dmxnet.traceLeakedObjects=true -cp $CLASS_PATH \
+ org.apache.mxnetexamples.javaapi.infer.bert.BertQA $@
diff --git a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertDataParser.java b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertDataParser.java
new file mode 100644
index 000000000000..440670afc098
--- /dev/null
+++ b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertDataParser.java
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnetexamples.javaapi.infer.bert;
+
+import java.io.FileReader;
+import java.util.*;
+
+import com.google.gson.Gson;
+import com.google.gson.JsonArray;
+import com.google.gson.JsonElement;
+import com.google.gson.JsonObject;
+
+/**
+ * This is the Utility for pre-processing the data for Bert Model
+ * You can use this utility to parse Vocabulary JSON into Java Array and Dictionary,
+ * clean and tokenize sentences and pad the text
+ */
+public class BertDataParser {
+
+ private Map token2idx;
+ private List idx2token;
+
+ /**
+ * Parse the Vocabulary to JSON files
+ * [PAD], [CLS], [SEP], [MASK], [UNK] are reserved tokens
+ * @param jsonFile the filePath of the vocab.json
+ * @throws Exception
+ */
+ void parseJSON(String jsonFile) throws Exception {
+ Gson gson = new Gson();
+ token2idx = new HashMap<>();
+ idx2token = new LinkedList<>();
+ JsonObject jsonObject = gson.fromJson(new FileReader(jsonFile), JsonObject.class);
+ JsonArray arr = jsonObject.getAsJsonArray("idx_to_token");
+ for (JsonElement element : arr) {
+ idx2token.add(element.getAsString());
+ }
+ JsonObject preMap = jsonObject.getAsJsonObject("token_to_idx");
+ for (String key : preMap.keySet()) {
+ token2idx.put(key, preMap.get(key).getAsInt());
+ }
+ }
+
+ /**
+ * Tokenize the input, split all kinds of whitespace and
+ * Separate the end of sentence symbol: . , ? !
+ * @param input The input string
+ * @return List of tokens
+ */
+ List tokenizer(String input) {
+ String[] step1 = input.split("\\s+");
+ List finalResult = new LinkedList<>();
+ for (String item : step1) {
+ if (item.length() != 0) {
+ if ((item + "a").split("[.,?!]+").length > 1) {
+ finalResult.add(item.substring(0, item.length() - 1));
+ finalResult.add(item.substring(item.length() -1));
+ } else {
+ finalResult.add(item);
+ }
+ }
+ }
+ return finalResult;
+ }
+
+ /**
+ * Pad the tokens to the required length
+ * @param tokens input tokens
+ * @param padItem things to pad at the end
+ * @param num total length after padding
+ * @return List of padded tokens
+ */
+ List pad(List tokens, E padItem, int num) {
+ if (tokens.size() >= num) return tokens;
+ List padded = new LinkedList<>(tokens);
+ for (int i = 0; i < num - tokens.size(); i++) {
+ padded.add(padItem);
+ }
+ return padded;
+ }
+
+ /**
+ * Convert tokens to indexes
+ * @param tokens input tokens
+ * @return List of indexes
+ */
+ List token2idx(List tokens) {
+ List indexes = new ArrayList<>();
+ for (String token : tokens) {
+ if (token2idx.containsKey(token)) {
+ indexes.add(token2idx.get(token));
+ } else {
+ indexes.add(token2idx.get("[UNK]"));
+ }
+ }
+ return indexes;
+ }
+
+ /**
+ * Convert indexes to tokens
+ * @param indexes List of indexes
+ * @return List of tokens
+ */
+ List idx2token(List indexes) {
+ List tokens = new ArrayList<>();
+ for (int index : indexes) {
+ tokens.add(idx2token.get(index));
+ }
+ return tokens;
+ }
+}
diff --git a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertQA.java b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertQA.java
new file mode 100644
index 000000000000..b40a4e94afbd
--- /dev/null
+++ b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertQA.java
@@ -0,0 +1,148 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnetexamples.javaapi.infer.bert;
+
+import org.apache.mxnet.infer.javaapi.Predictor;
+import org.apache.mxnet.javaapi.*;
+import org.kohsuke.args4j.CmdLineParser;
+import org.kohsuke.args4j.Option;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.*;
+
+/**
+ * This is an example of using BERT to do the general Question and Answer inference jobs
+ * Users can provide a question with a paragraph contains answer to the model and
+ * the model will be able to find the best answer from the answer paragraph
+ */
+public class BertQA {
+ @Option(name = "--model-path-prefix", usage = "input model directory and prefix of the model")
+ private String modelPathPrefix = "/model/static_bert_qa";
+ @Option(name = "--model-epoch", usage = "Epoch number of the model")
+ private int epoch = 2;
+ @Option(name = "--model-vocab", usage = "the vocabulary used in the model")
+ private String modelVocab = "/model/vocab.json";
+ @Option(name = "--input-question", usage = "the input question")
+ private String inputQ = "When did BBC Japan start broadcasting?";
+ @Option(name = "--input-answer", usage = "the input answer")
+ private String inputA =
+ "BBC Japan was a general entertainment Channel.\n" +
+ " Which operated between December 2004 and April 2006.\n" +
+ "It ceased operations after its Japanese distributor folded.";
+ @Option(name = "--seq-length", usage = "the maximum length of the sequence")
+ private int seqLength = 384;
+
+ private final static Logger logger = LoggerFactory.getLogger(BertQA.class);
+ private static NDArray$ NDArray = NDArray$.MODULE$;
+
+ private static int argmax(float[] prob) {
+ int maxIdx = 0;
+ for (int i = 0; i < prob.length; i++) {
+ if (prob[maxIdx] < prob[i]) maxIdx = i;
+ }
+ return maxIdx;
+ }
+
+ /**
+ * Do the post processing on the output, apply softmax to get the probabilities
+ * reshape and get the most probable index
+ * @param result prediction result
+ * @param tokens word tokens
+ * @return Answers clipped from the original paragraph
+ */
+ static List postProcessing(NDArray result, List tokens) {
+ NDArray[] output = NDArray.split(
+ NDArray.new splitParam(result, 2).setAxis(2));
+ // Get the formatted logits result
+ NDArray startLogits = output[0].reshape(new int[]{0, -3});
+ NDArray endLogits = output[1].reshape(new int[]{0, -3});
+ // Get Probability distribution
+ float[] startProb = NDArray.softmax(
+ NDArray.new softmaxParam(startLogits))[0].toArray();
+ float[] endProb = NDArray.softmax(
+ NDArray.new softmaxParam(endLogits))[0].toArray();
+ int startIdx = argmax(startProb);
+ int endIdx = argmax(endProb);
+ return tokens.subList(startIdx, endIdx + 1);
+ }
+
+ public static void main(String[] args) throws Exception{
+ BertQA inst = new BertQA();
+ CmdLineParser parser = new CmdLineParser(inst);
+ parser.parseArgument(args);
+ BertDataParser util = new BertDataParser();
+ Context context = Context.cpu();
+ if (System.getenv().containsKey("SCALA_TEST_ON_GPU") &&
+ Integer.valueOf(System.getenv("SCALA_TEST_ON_GPU")) == 1) {
+ context = Context.gpu();
+ }
+ // pre-processing - tokenize sentence
+ List tokenQ = util.tokenizer(inst.inputQ.toLowerCase());
+ List tokenA = util.tokenizer(inst.inputA.toLowerCase());
+ int validLength = tokenQ.size() + tokenA.size();
+ logger.info("Valid length: " + validLength);
+ // generate token types [0000...1111....0000]
+ List QAEmbedded = new ArrayList<>();
+ util.pad(QAEmbedded, 0f, tokenQ.size()).addAll(
+ util.pad(new ArrayList(), 1f, tokenA.size())
+ );
+ List tokenTypes = util.pad(QAEmbedded, 0f, inst.seqLength);
+ // make BERT pre-processing standard
+ tokenQ.add("[SEP]");
+ tokenQ.add(0, "[CLS]");
+ tokenA.add("[SEP]");
+ tokenQ.addAll(tokenA);
+ List tokens = util.pad(tokenQ, "[PAD]", inst.seqLength);
+ logger.info("Pre-processed tokens: " + Arrays.toString(tokenQ.toArray()));
+ // pre-processing - token to index translation
+ util.parseJSON(inst.modelVocab);
+ List indexes = util.token2idx(tokens);
+ List indexesFloat = new ArrayList<>();
+ for (int integer : indexes) {
+ indexesFloat.add((float) integer);
+ }
+ // Preparing the input data
+ List inputBatch = Arrays.asList(
+ new NDArray(indexesFloat,
+ new Shape(new int[]{1, inst.seqLength}), context),
+ new NDArray(tokenTypes,
+ new Shape(new int[]{1, inst.seqLength}), context),
+ new NDArray(new float[] { validLength },
+ new Shape(new int[]{1}), context)
+ );
+ // Build the model
+ List contexts = new ArrayList<>();
+ contexts.add(context);
+ List inputDescs = Arrays.asList(
+ new DataDesc("data0",
+ new Shape(new int[]{1, inst.seqLength}), DType.Float32(), Layout.NT()),
+ new DataDesc("data1",
+ new Shape(new int[]{1, inst.seqLength}), DType.Float32(), Layout.NT()),
+ new DataDesc("data2",
+ new Shape(new int[]{1}), DType.Float32(), Layout.N())
+ );
+ Predictor bertQA = new Predictor(inst.modelPathPrefix, inputDescs, contexts, inst.epoch);
+ // Start prediction
+ NDArray result = bertQA.predictWithNDArray(inputBatch).get(0);
+ List answer = postProcessing(result, tokens);
+ logger.info("Question: " + inst.inputQ);
+ logger.info("Answer paragraph: " + inst.inputA);
+ logger.info("Answer: " + Arrays.toString(answer.toArray()));
+ }
+}
diff --git a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/README.md b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/README.md
new file mode 100644
index 000000000000..7925a259f48f
--- /dev/null
+++ b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/README.md
@@ -0,0 +1,103 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+# Run BERT QA model using Java Inference API
+
+In this tutorial, we will walk through the BERT QA model trained by MXNet.
+Users can provide a question with a paragraph contains answer to the model and
+the model will be able to find the best answer from the answer paragraph.
+
+Example:
+```text
+Q: When did BBC Japan start broadcasting?
+```
+
+Answer paragraph
+```text
+BBC Japan was a general entertainment channel, which operated between December 2004 and April 2006.
+It ceased operations after its Japanese distributor folded.
+```
+And it picked up the right one:
+```text
+A: December 2004
+```
+
+## Setup Guide
+
+### Step 1: Download the model
+
+For this tutorial, you can get the model and vocabulary by running following bash file. This script will use `wget` to download these artifacts from AWS S3.
+
+From the `scala-package/examples/scripts/infer/bert/` folder run:
+
+```bash
+./get_bert_data.sh
+```
+
+### Step 2: Setup data path of the model
+
+### Setup Datapath and Parameters
+
+The available arguments are as follows:
+
+| Argument | Comments |
+| ----------------------------- | ---------------------------------------- |
+| `--model-path-prefix` | Folder path with prefix to the model (including json, params). |
+| `--model-vocab` | Vocabulary path |
+| `--model-epoch` | Epoch number of the model |
+| `--input-question` | Question that asked to the model |
+| `--input-answer` | Paragraph that contains the answer |
+| `--seq-length` | Sequence Length of the model (384 by default) |
+
+### Step 3: Run Inference
+After the previous steps, you should be able to run the code using the following script that will pass all of the required parameters to the Infer API.
+
+From the `scala-package/examples/scripts/infer/bert/` folder run:
+
+```bash
+./run_bert_qa_example.sh --model-path-prefix ../models/static-bert-qa/static_bert_qa \
+ --model-vocab ../models/static-bert-qa/vocab.json \
+ --model-epoch 2
+```
+
+## Background
+
+To learn more about how BERT works in MXNet, please follow this [MXNet Gluon tutorial on NLP using BERT](https://medium.com/apache-mxnet/gluon-nlp-bert-6a489bdd3340).
+
+The model was extracted from MXNet GluonNLP with static length settings.
+
+[Download link for the script](https://gluon-nlp.mxnet.io/_downloads/bert.zip)
+
+The original description can be found in the [MXNet GluonNLP model zoo](https://gluon-nlp.mxnet.io/model_zoo/bert/index.html#bert-base-on-squad-1-1).
+```bash
+python static_finetune_squad.py --optimizer adam --accumulate 2 --batch_size 6 --lr 3e-5 --epochs 2 --gpu 0 --export
+
+```
+This script will generate `json` and `param` fles that are the standard MXNet model files.
+By default, this model are using `bert_12_768_12` model with extra layers for QA jobs.
+
+After that, to be able to use it in Java, we need to export the dictionary from the script to parse the text
+to actual indexes. Please add the following lines after [this line](/~https://github.com/dmlc/gluon-nlp/blob/master/scripts/bert/staticbert/static_finetune_squad.py#L262).
+```python
+import json
+json_str = vocab.to_json()
+f = open("vocab.json", "w")
+f.write(json_str)
+f.close()
+```
+This would export the token vocabulary in json format.
+Once you have these three files, you will be able to run this example without problems.
diff --git a/scala-package/examples/src/test/java/org/apache/mxnetexamples/javaapi/infer/predictor/BertExampleTest.java b/scala-package/examples/src/test/java/org/apache/mxnetexamples/javaapi/infer/predictor/BertExampleTest.java
new file mode 100644
index 000000000000..0518254c297d
--- /dev/null
+++ b/scala-package/examples/src/test/java/org/apache/mxnetexamples/javaapi/infer/predictor/BertExampleTest.java
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnetexamples.javaapi.infer.predictor;
+
+import org.apache.mxnetexamples.Util;
+import org.apache.mxnetexamples.javaapi.infer.bert.BertQA;
+import org.junit.BeforeClass;
+import org.junit.Test;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.File;
+
+/**
+ * Test on BERT QA model
+ */
+public class BertExampleTest {
+ final static Logger logger = LoggerFactory.getLogger(BertExampleTest.class);
+ private static String modelPathPrefix = "";
+ private static String vocabPath = "";
+
+ @BeforeClass
+ public static void downloadFile() {
+ logger.info("Downloading Bert QA Model");
+ String tempDirPath = System.getProperty("java.io.tmpdir");
+ logger.info("tempDirPath: %s".format(tempDirPath));
+
+ String baseUrl = "https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/BertQA";
+ Util.downloadUrl(baseUrl + "/static_bert_qa-symbol.json",
+ tempDirPath + "/static_bert_qa/static_bert_qa-symbol.json", 3);
+ Util.downloadUrl(baseUrl + "/static_bert_qa-0002.params",
+ tempDirPath + "/static_bert_qa/static_bert_qa-0002.params", 3);
+ Util.downloadUrl(baseUrl + "/vocab.json",
+ tempDirPath + "/static_bert_qa/vocab.json", 3);
+ modelPathPrefix = tempDirPath + File.separator + "static_bert_qa/static_bert_qa";
+ vocabPath = tempDirPath + File.separator + "static_bert_qa/vocab.json";
+ }
+
+ @Test
+ public void testBertQA() throws Exception{
+ BertQA bert = new BertQA();
+ String Q = "When did BBC Japan start broadcasting?";
+ String A = "BBC Japan was a general entertainment Channel.\n" +
+ " Which operated between December 2004 and April 2006.\n" +
+ "It ceased operations after its Japanese distributor folded.";
+ String[] args = new String[] {
+ "--model-path-prefix", modelPathPrefix,
+ "--model-vocab", vocabPath,
+ "--model-epoch", "2",
+ "--input-question", Q,
+ "--input-answer", A,
+ "--seq-length", "384"
+ };
+ bert.main(args);
+ }
+}