Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[wasm] Add AvgPool kernel. #2411

Merged
merged 14 commits into from
Nov 20, 2019
5 changes: 4 additions & 1 deletion tfjs-backend-wasm/karma.conf.js
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@ const karmaTypescriptConfig = {
sourceMap: true,
// Ignore the import of the `worker_threads` package used in a core test
// meant to run in node.
exclude: ['worker_threads']
exclude: ['worker_threads'],
// worker_node_test in tfjs-core contains a conditional require statement
// that confuses the bundler of karma-typescript.
ignore: ['./worker_node_test']
},
// Disable coverage reports and instrumentation by default for tests
coverageOptions: {instrumentation: false},
Expand Down
19 changes: 19 additions & 0 deletions tfjs-backend-wasm/src/cc/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ tfjs_cc_library(
deps = [
":Abs",
":Add",
":AvgPool",
":AddN",
":BatchMatMul",
":MaxPool",
Expand Down Expand Up @@ -109,6 +110,16 @@ tfjs_cc_library(
]
)

tfjs_cc_library(
name = "AvgPool",
srcs = ["kernels/AvgPool.cc"],
hdrs = ["kernels/AvgPool.h"],
deps = [
":backend",
":util"
]
)

tfjs_cc_library(
name = "FusedBatchNorm",
srcs = ["kernels/FusedBatchNorm.cc"],
Expand Down Expand Up @@ -348,6 +359,14 @@ tfjs_unit_test(
]
)

tfjs_unit_test(
name = "AvgPool_test",
srcs = ["kernels/AvgPool_test.cc"],
deps = [
":AvgPool"
]
)

tfjs_unit_test(
name = "MaxPool_test",
srcs = ["kernels/MaxPool_test.cc"],
Expand Down
107 changes: 107 additions & 0 deletions tfjs-backend-wasm/src/cc/kernels/AvgPool.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
/* Copyright 2019 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* ===========================================================================*/

#ifdef __EMSCRIPTEN__
#include <emscripten.h>
#endif

#include <xnnpack.h>
#include <array>
#include <cmath>
#include <limits>
#include <map>
#include <unordered_map>

#include "src/cc/backend.h"
#include "src/cc/kernels/AvgPool.h"
#include "src/cc/util.h"

namespace {
typedef std::array<int, 14> OperatorCacheKey;

std::map<OperatorCacheKey, xnn_operator_t> operator_cache;
} // namespace

namespace tfjs {
namespace wasm {
extern "C" {

#ifdef __EMSCRIPTEN__
EMSCRIPTEN_KEEPALIVE
#endif
void AvgPool(const int x_id, const int batch_size, const int input_height,
const int input_width, const int filter_height,
const int filter_width, int pad_top, int pad_right, int pad_bottom,
int pad_left, const int stride_height, const int stride_width,
const int channels, const int out_id) {
auto& x_info = backend::get_tensor_info(x_id);
auto& out_info = backend::get_tensor_info(out_id);

const float* x_buf = reinterpret_cast<float*>(x_info.memory_offset);
float* out_buf = reinterpret_cast<float*>(out_info.memory_offset);

xnn_operator_t avg_pool_op = nullptr;

const int flags = 0;
const int input_pixel_stride = channels;
const int output_pixel_stride = channels;

OperatorCacheKey cache_key = {
pad_top, pad_right, pad_bottom, pad_left,
filter_height, filter_width, stride_height, stride_width,
channels, input_pixel_stride, output_pixel_stride, flags};

auto operator_cache_idx = operator_cache.find(cache_key);

if (operator_cache_idx == operator_cache.end()) {
float output_min = -std::numeric_limits<float>::infinity();
float output_max = std::numeric_limits<float>::infinity();

xnn_status status = xnn_create_average_pooling2d_nhwc_f32(
pad_top, pad_right, pad_bottom, pad_left, filter_height, filter_width,
stride_height, stride_width, channels, input_pixel_stride,
output_pixel_stride, output_min, output_max, flags, &avg_pool_op);

if (status != xnn_status_success) {
util::warn(
"XNN status for xnn_create_average_pooling2d_nhwc_f32 is not "
"successful. ",
"Got status %d. Use -c dbg to see XNN logs.", status);
return;
}

operator_cache.emplace(cache_key, avg_pool_op);

tfjs::backend::xnn_operator_count++;
} else {
avg_pool_op = operator_cache_idx->second;
}

xnn_status status = xnn_setup_average_pooling2d_nhwc_f32(
avg_pool_op, batch_size, input_height, input_width, x_buf, out_buf,
nullptr /* thread pool */);
if (status != xnn_status_success) {
util::warn(
"XNN status for xnn_setup_average_pooling2d_nhwc_f32 is not "
"successful. "
"Got status %d. Use -c dbg to see XNN logs.",
status);
return;
}

xnn_run_operator(avg_pool_op, nullptr /* thread pool */);
}
} // extern "C"
} // namespace wasm
} // namespace tfjs
32 changes: 32 additions & 0 deletions tfjs-backend-wasm/src/cc/kernels/AvgPool.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/* Copyright 2019 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* ===========================================================================*/

#ifndef KERNELS_AVGPOOL_H_
#define KERNELS_AVGPOOL_H_

namespace tfjs {

namespace wasm {
extern "C" {
void AvgPool(const int x_id, const int batch_size, const int input_height,
const int input_width, const int filter_height,
const int filter_width, int pad_top, int pad_right, int pad_bottom,
int pad_left, const int stride_height, const int stride_width,
const int channels, const int out_id);
}

} // namespace wasm
} // namespace tfjs

#endif // KERNELS_AVGPOOL_H_
78 changes: 78 additions & 0 deletions tfjs-backend-wasm/src/cc/kernels/AvgPool_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@

/* Copyright 2019 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* ===========================================================================*/

#include <gtest/gtest.h>

#include "src/cc/backend.h"
#include "src/cc/kernels/AvgPool.h"

TEST(MAXPOOL, xnn_operator_lifetime) {
tfjs::wasm::init();

ASSERT_EQ(0, tfjs::backend::num_tensors());

const int x0_id = 0;
const int x1_id = 1;
const int size = 9;
float x_values[size] = {1, 2, 3, 4, 5, 6, 7, 8, 9};

const int out_id = 2;
const int out_size = 9;
float out_values[out_size] = {};

tfjs::wasm::register_tensor(x0_id, size, x_values);
tfjs::wasm::register_tensor(x1_id, size, x_values);
tfjs::wasm::register_tensor(out_id, out_size, out_values);

ASSERT_EQ(3, tfjs::backend::num_tensors());
ASSERT_EQ(0, tfjs::backend::xnn_operator_count);

// One xnn_operator should be created for first call to avgPool.
const int batch_size = 1;
const int input_height = 3;
const int input_width = 3;
const int filter_height = 2;
const int filter_width = 2;
const int pad_top = 0;
const int pad_right = 1;
const int pad_bottom = 1;
const int pad_left = 0;
const int stride_height = 1;
const int stride_width = 1;
const int channels = 1;
tfjs::wasm::AvgPool(x0_id, batch_size, input_height, input_width,
filter_height, filter_width, pad_top, pad_right,
pad_bottom, pad_left, stride_height, stride_width,
channels, out_id);
ASSERT_EQ(1, tfjs::backend::xnn_operator_count);

// No new xnn_operators should be created for the second call to avgPool with
// the same arguments.
tfjs::wasm::AvgPool(x0_id, batch_size, input_height, input_width,
filter_height, filter_width, pad_top, pad_right,
pad_bottom, pad_left, stride_height, stride_width,
channels, out_id);
ASSERT_EQ(1, tfjs::backend::xnn_operator_count);

// One new xnn_operator should be created for the next call to avgPool with
// 'valid' padding.
tfjs::wasm::AvgPool(x0_id, batch_size, input_height, input_width,
filter_height, filter_width, pad_top, 0 /* pad_right */,
0 /* pad_bottom */, pad_left, stride_height, stride_width,
channels, out_id);
ASSERT_EQ(2, tfjs::backend::xnn_operator_count);

tfjs::wasm::dispose();
}
100 changes: 100 additions & 0 deletions tfjs-backend-wasm/src/kernels/AvgPool.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
/**
* @license
* Copyright 2019 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {backend_util, KernelFunc, NamedTensorInfoMap, registerKernel, TensorInfo} from '@tensorflow/tfjs-core';

import {BackendWasm} from '../backend_wasm';

interface AvgPoolInputs extends NamedTensorInfoMap {
x: TensorInfo;
filter: TensorInfo;
}

let wasmAvgPool: (
xId: number, batchSize: number, inputHeight: number, inputWidth: number,
filterHeight: number, filterWidth: number, padTop: number, padRight: number,
padBottom: number, padLeft: number, strideHeight: number,
strideWidth: number, channels: number, outId: number) => void;

function setup(backend: BackendWasm) {
wasmAvgPool = backend.wasm.cwrap('AvgPool', null /* void */, [
'number', // xId
'number', // batchSize
'number', // inputHeight
'number', // inputWidth
'number', // filterHeight
'number', // filterWidth
'number', // padTop
'number', // padRight
'number', // padBottom
'number', // padLeft
'number', // strideHeight
'number', // strideWidth
'number', // channels
'number', // outId
]);
}

function avgPool(args: {
inputs: AvgPoolInputs,
backend: BackendWasm,
attrs: backend_util.Conv2DInfo
}) {
const {inputs, attrs, backend} = args;
const convInfo = attrs;

const {x} = inputs;
const xId = backend.dataIdMap.get(x.dataId).id;

const filterHeight = convInfo.filterHeight;
const filterWidth = convInfo.filterWidth;
const padTop = convInfo.padInfo.top;
const padRight = convInfo.padInfo.right;
const padBottom = convInfo.padInfo.bottom;
const padLeft = convInfo.padInfo.left;
const strideHeight = convInfo.strideHeight;
const strideWidth = convInfo.strideWidth;
const channels = convInfo.inChannels;

if (convInfo.dataFormat !== 'channelsLast') {
throw new Error(
`wasm backend does not support dataFormat:'` +
`${convInfo.dataFormat}'. Please use 'channelsLast'.`);
}

if (convInfo.dilationWidth !== 1 || convInfo.dilationHeight !== 1) {
throw new Error(
`was backend only supports average pooling with dilation = [1, 1], ` +
`got [${convInfo.dilationHeight}, ${convInfo.dilationWidth}].`);
}

const out = backend.makeOutput(convInfo.outShape, 'float32');
const outId = backend.dataIdMap.get(out.dataId).id;

wasmAvgPool(
xId, x.shape[0], x.shape[1], x.shape[2], filterHeight, filterWidth,
padTop, padRight, padBottom, padLeft, strideHeight, strideWidth, channels,
outId);
return out;
}

registerKernel({
kernelName: 'AvgPool',
backendName: 'wasm',
setupFunc: setup,
kernelFunc: avgPool as {} as KernelFunc
});
1 change: 1 addition & 0 deletions tfjs-backend-wasm/src/kernels/all_kernels.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
// the contents of this file and import only the kernels that are needed.
import './Abs';
import './Add';
import './AvgPool';
import './AddN';
import './BatchMatMul';
import './Cast';
Expand Down
Loading