Skip to content

Commit

Permalink
[WASM] Add Addn used by MobileNet (#2408)
Browse files Browse the repository at this point in the history
FEATURE

Add `Addn` used by MobileNet
  • Loading branch information
dsmilkov authored Nov 19, 2019
1 parent 324c1f1 commit 0818bd0
Show file tree
Hide file tree
Showing 6 changed files with 150 additions and 1 deletion.
10 changes: 10 additions & 0 deletions tfjs-backend-wasm/src/cc/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ tfjs_cc_library(
deps = [
":Abs",
":Add",
":AddN",
":BatchMatMul",
":MaxPool",
":ClipByValue",
Expand Down Expand Up @@ -154,6 +155,15 @@ tfjs_cc_library(
],
)

tfjs_cc_library(
name = "AddN",
srcs = ["kernels/AddN.cc"],
deps = [
":backend",
":util",
],
)

tfjs_cc_library(
name = "BatchMatMul",
srcs = ["kernels/BatchMatMul.cc"],
Expand Down
78 changes: 78 additions & 0 deletions tfjs-backend-wasm/src/cc/kernels/AddN.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
/* Copyright 2019 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* ===========================================================================*/

#ifdef __EMSCRIPTEN__
#include <emscripten.h>
#endif

#include <vector>

#include "src/cc/backend.h"
#include "src/cc/util.h"

namespace {

template <typename T>
void addn(const std::vector<const T*>& inputs_buf, const int size, T* out_buf) {
// Initialize the output to 0.
memset(out_buf, 0, size * sizeof(T));

for (size_t in_idx = 0; in_idx < inputs_buf.size(); ++in_idx) {
const T* input = inputs_buf[in_idx];
for (size_t i = 0; i < size; ++i) {
out_buf[i] += input[i];
}
}
}

} // namespace

namespace tfjs {
namespace wasm {
// We use C-style API to interface with Javascript.
extern "C" {

#ifdef __EMSCRIPTEN__
EMSCRIPTEN_KEEPALIVE
#endif
void AddN(const int* input_ids_ptr, const int input_ids_len, const DType dtype,
const int out_id) {
std::vector<int> inputs(input_ids_ptr, input_ids_ptr + input_ids_len);
auto& out_info = backend::get_tensor_info_out(out_id);
std::vector<void*> inputs_buf;
std::transform(
inputs.begin(), inputs.end(), std::back_inserter(inputs_buf),
[](int id) { return backend::get_tensor_info(id).memory_offset; });

switch (dtype) {
case DType::float32:
addn<float>(reinterpret_cast<std::vector<const float*>&>(inputs_buf),
out_info.size, out_info.f32_write());
break;
case DType::int32:
addn<int>(reinterpret_cast<std::vector<const int*>&>(inputs_buf),
out_info.size, out_info.i32_write());
break;
case DType::boolean:
addn<bool>(reinterpret_cast<std::vector<const bool*>&>(inputs_buf),
out_info.size, out_info.b_write());
break;
default:
util::warn("AddN failed. Unknown dtype %d", dtype);
}
}

} // extern "C"
} // namespace wasm
} // namespace tfjs
58 changes: 58 additions & 0 deletions tfjs-backend-wasm/src/kernels/AddN.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
/**
* @license
* Copyright 2019 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {KernelFunc, registerKernel, TensorInfo, util} from '@tensorflow/tfjs-core';

import {BackendWasm} from '../backend_wasm';
import {CppDType} from './types';

let wasmFunc:
(inputIds: Uint8Array, inputIdsLen: number, dtype: number, outId: number) =>
void;

function setupFunc(backend: BackendWasm): void {
wasmFunc = backend.wasm.cwrap('AddN', null /* void */, [
'array', // input_ids
'number', // input_ids.length
'number', // dtype
'number', // out_id
]);
}

function addn(args: {inputs: TensorInfo[], backend: BackendWasm}) {
const {inputs, backend} = args;
const out = backend.makeOutput(inputs[0].shape, inputs[0].dtype);

// Short-circuit zero-sized tensors.
if (util.sizeFromShape(out.shape) === 0) {
return out;
}

const inputIds = inputs.map(x => backend.dataIdMap.get(x.dataId).id);
const inputIdsBytes = new Uint8Array(new Int32Array(inputIds).buffer);
const outId = backend.dataIdMap.get(out.dataId).id;
wasmFunc(inputIdsBytes, inputIds.length, CppDType[out.dtype], outId);

return out;
}

registerKernel({
kernelName: 'AddN',
backendName: 'wasm',
setupFunc,
kernelFunc: addn as {} as KernelFunc,
});
1 change: 1 addition & 0 deletions tfjs-backend-wasm/src/kernels/all_kernels.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
// the contents of this file and import only the kernels that are needed.
import './Abs';
import './Add';
import './AddN';
import './BatchMatMul';
import './Cast';
import './ClipByValue';
Expand Down
1 change: 1 addition & 0 deletions tfjs-backend-wasm/src/setup_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ const TEST_FILTERS: TestFilter[] = [
},
{include: 'pad ', excludes: ['complex', 'zerosLike']},
{include: 'clip', excludes: ['gradient']},
{include: 'addN'},
];

const customInclude = (testName: string) => {
Expand Down
3 changes: 2 additions & 1 deletion tfjs-core/src/ops/binary_ops.ts
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,8 @@ function addN_<T extends Tensor>(tensors: Array<T|TensorLike>): T {
return ders;
};
const inputs: NamedTensorMap = $tensors as {} as NamedTensorMap;
return ENGINE.runKernelFunc(backend => backend.addN($tensors), inputs, der);
return ENGINE.runKernelFunc(
backend => backend.addN($tensors), inputs, der, 'AddN');
}

/**
Expand Down

0 comments on commit 0818bd0

Please sign in to comment.