Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Add BERT QA Scala/Java example (#14592)
Browse files Browse the repository at this point in the history
* add BertQA major code piece

* add scripts and bug fixes

* add integration test

* address comments

* address doc comments
  • Loading branch information
lanking520 authored Apr 5, 2019
1 parent 43f7c12 commit d5d1d7a
Show file tree
Hide file tree
Showing 8 changed files with 545 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.mxnet.javaapi

/**
* Layout definition of DataDesc
* N Batch size
* C channels
* H Height
* W Weight
* T sequence length
* __undefined__ default value of Layout
*/
object Layout {
val UNDEFINED: String = org.apache.mxnet.Layout.UNDEFINED
val NCHW: String = org.apache.mxnet.Layout.NCHW
val NTC: String = org.apache.mxnet.Layout.NTC
val NT: String = org.apache.mxnet.Layout.NT
val N: String = org.apache.mxnet.Layout.N
}
5 changes: 5 additions & 0 deletions scala-package/examples/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -145,5 +145,10 @@
<artifactId>slf4j-simple</artifactId>
<version>1.7.5</version>
</dependency>
<dependency>
<groupId>com.google.code.gson</groupId>
<artifactId>gson</artifactId>
<version>2.8.5</version>
</dependency>
</dependencies>
</project>
31 changes: 31 additions & 0 deletions scala-package/examples/scripts/infer/bert/get_bert_data.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
#!/bin/bash

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

set -e

MXNET_ROOT=$(cd "$(dirname $0)/../../.."; pwd)

data_path=$MXNET_ROOT/scripts/infer/models/static-bert-qa/

if [ ! -d "$data_path" ]; then
mkdir -p "$data_path"
curl https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/BertQA/vocab.json -o $data_path/vocab.json
curl https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/BertQA/static_bert_qa-0002.params -o $data_path/static_bert_qa-0002.params
curl https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/BertQA/static_bert_qa-symbol.json -o $data_path/static_bert_qa-symbol.json
fi
27 changes: 27 additions & 0 deletions scala-package/examples/scripts/infer/bert/run_bert_qa_example.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
#!/bin/bash

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

set -e

MXNET_ROOT=$(cd "$(dirname $0)/../../../../.."; pwd)

CLASS_PATH=$MXNET_ROOT/scala-package/assembly/target/*:$MXNET_ROOT/scala-package/examples/target/*:$MXNET_ROOT/scala-package/examples/target/classes/lib/*

java -Xmx8G -Dmxnet.traceLeakedObjects=true -cp $CLASS_PATH \
org.apache.mxnetexamples.javaapi.infer.bert.BertQA $@
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.mxnetexamples.javaapi.infer.bert;

import java.io.FileReader;
import java.util.*;

import com.google.gson.Gson;
import com.google.gson.JsonArray;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;

/**
* This is the Utility for pre-processing the data for Bert Model
* You can use this utility to parse Vocabulary JSON into Java Array and Dictionary,
* clean and tokenize sentences and pad the text
*/
public class BertDataParser {

private Map<String, Integer> token2idx;
private List<String> idx2token;

/**
* Parse the Vocabulary to JSON files
* [PAD], [CLS], [SEP], [MASK], [UNK] are reserved tokens
* @param jsonFile the filePath of the vocab.json
* @throws Exception
*/
void parseJSON(String jsonFile) throws Exception {
Gson gson = new Gson();
token2idx = new HashMap<>();
idx2token = new LinkedList<>();
JsonObject jsonObject = gson.fromJson(new FileReader(jsonFile), JsonObject.class);
JsonArray arr = jsonObject.getAsJsonArray("idx_to_token");
for (JsonElement element : arr) {
idx2token.add(element.getAsString());
}
JsonObject preMap = jsonObject.getAsJsonObject("token_to_idx");
for (String key : preMap.keySet()) {
token2idx.put(key, preMap.get(key).getAsInt());
}
}

/**
* Tokenize the input, split all kinds of whitespace and
* Separate the end of sentence symbol: . , ? !
* @param input The input string
* @return List of tokens
*/
List<String> tokenizer(String input) {
String[] step1 = input.split("\\s+");
List<String> finalResult = new LinkedList<>();
for (String item : step1) {
if (item.length() != 0) {
if ((item + "a").split("[.,?!]+").length > 1) {
finalResult.add(item.substring(0, item.length() - 1));
finalResult.add(item.substring(item.length() -1));
} else {
finalResult.add(item);
}
}
}
return finalResult;
}

/**
* Pad the tokens to the required length
* @param tokens input tokens
* @param padItem things to pad at the end
* @param num total length after padding
* @return List of padded tokens
*/
<E> List<E> pad(List<E> tokens, E padItem, int num) {
if (tokens.size() >= num) return tokens;
List<E> padded = new LinkedList<>(tokens);
for (int i = 0; i < num - tokens.size(); i++) {
padded.add(padItem);
}
return padded;
}

/**
* Convert tokens to indexes
* @param tokens input tokens
* @return List of indexes
*/
List<Integer> token2idx(List<String> tokens) {
List<Integer> indexes = new ArrayList<>();
for (String token : tokens) {
if (token2idx.containsKey(token)) {
indexes.add(token2idx.get(token));
} else {
indexes.add(token2idx.get("[UNK]"));
}
}
return indexes;
}

/**
* Convert indexes to tokens
* @param indexes List of indexes
* @return List of tokens
*/
List<String> idx2token(List<Integer> indexes) {
List<String> tokens = new ArrayList<>();
for (int index : indexes) {
tokens.add(idx2token.get(index));
}
return tokens;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.mxnetexamples.javaapi.infer.bert;

import org.apache.mxnet.infer.javaapi.Predictor;
import org.apache.mxnet.javaapi.*;
import org.kohsuke.args4j.CmdLineParser;
import org.kohsuke.args4j.Option;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.*;

/**
* This is an example of using BERT to do the general Question and Answer inference jobs
* Users can provide a question with a paragraph contains answer to the model and
* the model will be able to find the best answer from the answer paragraph
*/
public class BertQA {
@Option(name = "--model-path-prefix", usage = "input model directory and prefix of the model")
private String modelPathPrefix = "/model/static_bert_qa";
@Option(name = "--model-epoch", usage = "Epoch number of the model")
private int epoch = 2;
@Option(name = "--model-vocab", usage = "the vocabulary used in the model")
private String modelVocab = "/model/vocab.json";
@Option(name = "--input-question", usage = "the input question")
private String inputQ = "When did BBC Japan start broadcasting?";
@Option(name = "--input-answer", usage = "the input answer")
private String inputA =
"BBC Japan was a general entertainment Channel.\n" +
" Which operated between December 2004 and April 2006.\n" +
"It ceased operations after its Japanese distributor folded.";
@Option(name = "--seq-length", usage = "the maximum length of the sequence")
private int seqLength = 384;

private final static Logger logger = LoggerFactory.getLogger(BertQA.class);
private static NDArray$ NDArray = NDArray$.MODULE$;

private static int argmax(float[] prob) {
int maxIdx = 0;
for (int i = 0; i < prob.length; i++) {
if (prob[maxIdx] < prob[i]) maxIdx = i;
}
return maxIdx;
}

/**
* Do the post processing on the output, apply softmax to get the probabilities
* reshape and get the most probable index
* @param result prediction result
* @param tokens word tokens
* @return Answers clipped from the original paragraph
*/
static List<String> postProcessing(NDArray result, List<String> tokens) {
NDArray[] output = NDArray.split(
NDArray.new splitParam(result, 2).setAxis(2));
// Get the formatted logits result
NDArray startLogits = output[0].reshape(new int[]{0, -3});
NDArray endLogits = output[1].reshape(new int[]{0, -3});
// Get Probability distribution
float[] startProb = NDArray.softmax(
NDArray.new softmaxParam(startLogits))[0].toArray();
float[] endProb = NDArray.softmax(
NDArray.new softmaxParam(endLogits))[0].toArray();
int startIdx = argmax(startProb);
int endIdx = argmax(endProb);
return tokens.subList(startIdx, endIdx + 1);
}

public static void main(String[] args) throws Exception{
BertQA inst = new BertQA();
CmdLineParser parser = new CmdLineParser(inst);
parser.parseArgument(args);
BertDataParser util = new BertDataParser();
Context context = Context.cpu();
if (System.getenv().containsKey("SCALA_TEST_ON_GPU") &&
Integer.valueOf(System.getenv("SCALA_TEST_ON_GPU")) == 1) {
context = Context.gpu();
}
// pre-processing - tokenize sentence
List<String> tokenQ = util.tokenizer(inst.inputQ.toLowerCase());
List<String> tokenA = util.tokenizer(inst.inputA.toLowerCase());
int validLength = tokenQ.size() + tokenA.size();
logger.info("Valid length: " + validLength);
// generate token types [0000...1111....0000]
List<Float> QAEmbedded = new ArrayList<>();
util.pad(QAEmbedded, 0f, tokenQ.size()).addAll(
util.pad(new ArrayList<Float>(), 1f, tokenA.size())
);
List<Float> tokenTypes = util.pad(QAEmbedded, 0f, inst.seqLength);
// make BERT pre-processing standard
tokenQ.add("[SEP]");
tokenQ.add(0, "[CLS]");
tokenA.add("[SEP]");
tokenQ.addAll(tokenA);
List<String> tokens = util.pad(tokenQ, "[PAD]", inst.seqLength);
logger.info("Pre-processed tokens: " + Arrays.toString(tokenQ.toArray()));
// pre-processing - token to index translation
util.parseJSON(inst.modelVocab);
List<Integer> indexes = util.token2idx(tokens);
List<Float> indexesFloat = new ArrayList<>();
for (int integer : indexes) {
indexesFloat.add((float) integer);
}
// Preparing the input data
List<NDArray> inputBatch = Arrays.asList(
new NDArray(indexesFloat,
new Shape(new int[]{1, inst.seqLength}), context),
new NDArray(tokenTypes,
new Shape(new int[]{1, inst.seqLength}), context),
new NDArray(new float[] { validLength },
new Shape(new int[]{1}), context)
);
// Build the model
List<Context> contexts = new ArrayList<>();
contexts.add(context);
List<DataDesc> inputDescs = Arrays.asList(
new DataDesc("data0",
new Shape(new int[]{1, inst.seqLength}), DType.Float32(), Layout.NT()),
new DataDesc("data1",
new Shape(new int[]{1, inst.seqLength}), DType.Float32(), Layout.NT()),
new DataDesc("data2",
new Shape(new int[]{1}), DType.Float32(), Layout.N())
);
Predictor bertQA = new Predictor(inst.modelPathPrefix, inputDescs, contexts, inst.epoch);
// Start prediction
NDArray result = bertQA.predictWithNDArray(inputBatch).get(0);
List<String> answer = postProcessing(result, tokens);
logger.info("Question: " + inst.inputQ);
logger.info("Answer paragraph: " + inst.inputA);
logger.info("Answer: " + Arrays.toString(answer.toArray()));
}
}
Loading

0 comments on commit d5d1d7a

Please sign in to comment.