diff --git a/tfjs-backend-wasm/karma.conf.js b/tfjs-backend-wasm/karma.conf.js index ab533eb958e..c8f3c315bd1 100644 --- a/tfjs-backend-wasm/karma.conf.js +++ b/tfjs-backend-wasm/karma.conf.js @@ -22,7 +22,10 @@ const karmaTypescriptConfig = { sourceMap: true, // Ignore the import of the `worker_threads` package used in a core test // meant to run in node. - exclude: ['worker_threads'] + exclude: ['worker_threads'], + // worker_node_test in tfjs-core contains a conditional require statement + // that confuses the bundler of karma-typescript. + ignore: ['./worker_node_test'] }, // Disable coverage reports and instrumentation by default for tests coverageOptions: {instrumentation: false}, diff --git a/tfjs-backend-wasm/src/cc/BUILD b/tfjs-backend-wasm/src/cc/BUILD index d0606dca135..f8110196f88 100644 --- a/tfjs-backend-wasm/src/cc/BUILD +++ b/tfjs-backend-wasm/src/cc/BUILD @@ -78,6 +78,7 @@ tfjs_cc_library( deps = [ ":Abs", ":Add", + ":AvgPool", ":AddN", ":BatchMatMul", ":MaxPool", @@ -109,6 +110,16 @@ tfjs_cc_library( ] ) +tfjs_cc_library( + name = "AvgPool", + srcs = ["kernels/AvgPool.cc"], + hdrs = ["kernels/AvgPool.h"], + deps = [ + ":backend", + ":util" + ] +) + tfjs_cc_library( name = "FusedBatchNorm", srcs = ["kernels/FusedBatchNorm.cc"], @@ -348,6 +359,14 @@ tfjs_unit_test( ] ) +tfjs_unit_test( + name = "AvgPool_test", + srcs = ["kernels/AvgPool_test.cc"], + deps = [ + ":AvgPool" + ] +) + tfjs_unit_test( name = "MaxPool_test", srcs = ["kernels/MaxPool_test.cc"], diff --git a/tfjs-backend-wasm/src/cc/kernels/AvgPool.cc b/tfjs-backend-wasm/src/cc/kernels/AvgPool.cc new file mode 100644 index 00000000000..8f50314b66e --- /dev/null +++ b/tfjs-backend-wasm/src/cc/kernels/AvgPool.cc @@ -0,0 +1,107 @@ +/* Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ===========================================================================*/ + +#ifdef __EMSCRIPTEN__ +#include +#endif + +#include +#include +#include +#include +#include +#include + +#include "src/cc/backend.h" +#include "src/cc/kernels/AvgPool.h" +#include "src/cc/util.h" + +namespace { +typedef std::array OperatorCacheKey; + +std::map operator_cache; +} // namespace + +namespace tfjs { +namespace wasm { +extern "C" { + +#ifdef __EMSCRIPTEN__ +EMSCRIPTEN_KEEPALIVE +#endif +void AvgPool(const int x_id, const int batch_size, const int input_height, + const int input_width, const int filter_height, + const int filter_width, int pad_top, int pad_right, int pad_bottom, + int pad_left, const int stride_height, const int stride_width, + const int channels, const int out_id) { + auto& x_info = backend::get_tensor_info(x_id); + auto& out_info = backend::get_tensor_info(out_id); + + const float* x_buf = reinterpret_cast(x_info.memory_offset); + float* out_buf = reinterpret_cast(out_info.memory_offset); + + xnn_operator_t avg_pool_op = nullptr; + + const int flags = 0; + const int input_pixel_stride = channels; + const int output_pixel_stride = channels; + + OperatorCacheKey cache_key = { + pad_top, pad_right, pad_bottom, pad_left, + filter_height, filter_width, stride_height, stride_width, + channels, input_pixel_stride, output_pixel_stride, flags}; + + auto operator_cache_idx = operator_cache.find(cache_key); + + if (operator_cache_idx == operator_cache.end()) { + float output_min = -std::numeric_limits::infinity(); + float output_max = std::numeric_limits::infinity(); + + xnn_status status = xnn_create_average_pooling2d_nhwc_f32( + pad_top, pad_right, pad_bottom, pad_left, filter_height, filter_width, + stride_height, stride_width, channels, input_pixel_stride, + output_pixel_stride, output_min, output_max, flags, &avg_pool_op); + + if (status != xnn_status_success) { + util::warn( + "XNN status for xnn_create_average_pooling2d_nhwc_f32 is not " + "successful. ", + "Got status %d. Use -c dbg to see XNN logs.", status); + return; + } + + operator_cache.emplace(cache_key, avg_pool_op); + + tfjs::backend::xnn_operator_count++; + } else { + avg_pool_op = operator_cache_idx->second; + } + + xnn_status status = xnn_setup_average_pooling2d_nhwc_f32( + avg_pool_op, batch_size, input_height, input_width, x_buf, out_buf, + nullptr /* thread pool */); + if (status != xnn_status_success) { + util::warn( + "XNN status for xnn_setup_average_pooling2d_nhwc_f32 is not " + "successful. " + "Got status %d. Use -c dbg to see XNN logs.", + status); + return; + } + + xnn_run_operator(avg_pool_op, nullptr /* thread pool */); +} +} // extern "C" +} // namespace wasm +} // namespace tfjs diff --git a/tfjs-backend-wasm/src/cc/kernels/AvgPool.h b/tfjs-backend-wasm/src/cc/kernels/AvgPool.h new file mode 100644 index 00000000000..44d138d5fa6 --- /dev/null +++ b/tfjs-backend-wasm/src/cc/kernels/AvgPool.h @@ -0,0 +1,32 @@ +/* Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ===========================================================================*/ + +#ifndef KERNELS_AVGPOOL_H_ +#define KERNELS_AVGPOOL_H_ + +namespace tfjs { + +namespace wasm { +extern "C" { +void AvgPool(const int x_id, const int batch_size, const int input_height, + const int input_width, const int filter_height, + const int filter_width, int pad_top, int pad_right, int pad_bottom, + int pad_left, const int stride_height, const int stride_width, + const int channels, const int out_id); +} + +} // namespace wasm +} // namespace tfjs + +#endif // KERNELS_AVGPOOL_H_ diff --git a/tfjs-backend-wasm/src/cc/kernels/AvgPool_test.cc b/tfjs-backend-wasm/src/cc/kernels/AvgPool_test.cc new file mode 100644 index 00000000000..626d0d5ea00 --- /dev/null +++ b/tfjs-backend-wasm/src/cc/kernels/AvgPool_test.cc @@ -0,0 +1,78 @@ + +/* Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ===========================================================================*/ + +#include + +#include "src/cc/backend.h" +#include "src/cc/kernels/AvgPool.h" + +TEST(MAXPOOL, xnn_operator_lifetime) { + tfjs::wasm::init(); + + ASSERT_EQ(0, tfjs::backend::num_tensors()); + + const int x0_id = 0; + const int x1_id = 1; + const int size = 9; + float x_values[size] = {1, 2, 3, 4, 5, 6, 7, 8, 9}; + + const int out_id = 2; + const int out_size = 9; + float out_values[out_size] = {}; + + tfjs::wasm::register_tensor(x0_id, size, x_values); + tfjs::wasm::register_tensor(x1_id, size, x_values); + tfjs::wasm::register_tensor(out_id, out_size, out_values); + + ASSERT_EQ(3, tfjs::backend::num_tensors()); + ASSERT_EQ(0, tfjs::backend::xnn_operator_count); + + // One xnn_operator should be created for first call to avgPool. + const int batch_size = 1; + const int input_height = 3; + const int input_width = 3; + const int filter_height = 2; + const int filter_width = 2; + const int pad_top = 0; + const int pad_right = 1; + const int pad_bottom = 1; + const int pad_left = 0; + const int stride_height = 1; + const int stride_width = 1; + const int channels = 1; + tfjs::wasm::AvgPool(x0_id, batch_size, input_height, input_width, + filter_height, filter_width, pad_top, pad_right, + pad_bottom, pad_left, stride_height, stride_width, + channels, out_id); + ASSERT_EQ(1, tfjs::backend::xnn_operator_count); + + // No new xnn_operators should be created for the second call to avgPool with + // the same arguments. + tfjs::wasm::AvgPool(x0_id, batch_size, input_height, input_width, + filter_height, filter_width, pad_top, pad_right, + pad_bottom, pad_left, stride_height, stride_width, + channels, out_id); + ASSERT_EQ(1, tfjs::backend::xnn_operator_count); + + // One new xnn_operator should be created for the next call to avgPool with + // 'valid' padding. + tfjs::wasm::AvgPool(x0_id, batch_size, input_height, input_width, + filter_height, filter_width, pad_top, 0 /* pad_right */, + 0 /* pad_bottom */, pad_left, stride_height, stride_width, + channels, out_id); + ASSERT_EQ(2, tfjs::backend::xnn_operator_count); + + tfjs::wasm::dispose(); +} diff --git a/tfjs-backend-wasm/src/kernels/AvgPool.ts b/tfjs-backend-wasm/src/kernels/AvgPool.ts new file mode 100644 index 00000000000..3b2d8ea2204 --- /dev/null +++ b/tfjs-backend-wasm/src/kernels/AvgPool.ts @@ -0,0 +1,100 @@ +/** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {backend_util, KernelFunc, NamedTensorInfoMap, registerKernel, TensorInfo} from '@tensorflow/tfjs-core'; + +import {BackendWasm} from '../backend_wasm'; + +interface AvgPoolInputs extends NamedTensorInfoMap { + x: TensorInfo; + filter: TensorInfo; +} + +let wasmAvgPool: ( + xId: number, batchSize: number, inputHeight: number, inputWidth: number, + filterHeight: number, filterWidth: number, padTop: number, padRight: number, + padBottom: number, padLeft: number, strideHeight: number, + strideWidth: number, channels: number, outId: number) => void; + +function setup(backend: BackendWasm) { + wasmAvgPool = backend.wasm.cwrap('AvgPool', null /* void */, [ + 'number', // xId + 'number', // batchSize + 'number', // inputHeight + 'number', // inputWidth + 'number', // filterHeight + 'number', // filterWidth + 'number', // padTop + 'number', // padRight + 'number', // padBottom + 'number', // padLeft + 'number', // strideHeight + 'number', // strideWidth + 'number', // channels + 'number', // outId + ]); +} + +function avgPool(args: { + inputs: AvgPoolInputs, + backend: BackendWasm, + attrs: backend_util.Conv2DInfo +}) { + const {inputs, attrs, backend} = args; + const convInfo = attrs; + + const {x} = inputs; + const xId = backend.dataIdMap.get(x.dataId).id; + + const filterHeight = convInfo.filterHeight; + const filterWidth = convInfo.filterWidth; + const padTop = convInfo.padInfo.top; + const padRight = convInfo.padInfo.right; + const padBottom = convInfo.padInfo.bottom; + const padLeft = convInfo.padInfo.left; + const strideHeight = convInfo.strideHeight; + const strideWidth = convInfo.strideWidth; + const channels = convInfo.inChannels; + + if (convInfo.dataFormat !== 'channelsLast') { + throw new Error( + `wasm backend does not support dataFormat:'` + + `${convInfo.dataFormat}'. Please use 'channelsLast'.`); + } + + if (convInfo.dilationWidth !== 1 || convInfo.dilationHeight !== 1) { + throw new Error( + `was backend only supports average pooling with dilation = [1, 1], ` + + `got [${convInfo.dilationHeight}, ${convInfo.dilationWidth}].`); + } + + const out = backend.makeOutput(convInfo.outShape, 'float32'); + const outId = backend.dataIdMap.get(out.dataId).id; + + wasmAvgPool( + xId, x.shape[0], x.shape[1], x.shape[2], filterHeight, filterWidth, + padTop, padRight, padBottom, padLeft, strideHeight, strideWidth, channels, + outId); + return out; +} + +registerKernel({ + kernelName: 'AvgPool', + backendName: 'wasm', + setupFunc: setup, + kernelFunc: avgPool as {} as KernelFunc +}); diff --git a/tfjs-backend-wasm/src/kernels/all_kernels.ts b/tfjs-backend-wasm/src/kernels/all_kernels.ts index 504befba613..03eeae91e17 100644 --- a/tfjs-backend-wasm/src/kernels/all_kernels.ts +++ b/tfjs-backend-wasm/src/kernels/all_kernels.ts @@ -20,6 +20,7 @@ // the contents of this file and import only the kernels that are needed. import './Abs'; import './Add'; +import './AvgPool'; import './AddN'; import './BatchMatMul'; import './Cast'; diff --git a/tfjs-backend-wasm/src/setup_test.ts b/tfjs-backend-wasm/src/setup_test.ts index 63b4616f9e2..f670a4cecd6 100644 --- a/tfjs-backend-wasm/src/setup_test.ts +++ b/tfjs-backend-wasm/src/setup_test.ts @@ -36,12 +36,18 @@ const TEST_FILTERS: TestFilter[] = [ 'complex', // Complex numbers not supported yet ] }, + { + include: 'avgPool', + excludes: [ + 'gradient', // Not yet implemented. + 'avgPool3d', // Not yet implemented. + ] + }, { include: 'maxPool', excludes: [ - 'f=[1,1]', // XNN does not support filter height and width of 1. - 'maxPoolBackprop', // Not yet implemented. - 'maxPool3d', // Not yet implemented. + 'maxPoolBackprop', // Not yet implemented. + 'maxPool3d', // Not yet implemented. 'maxPool3dBackprop', // Not yet implemented. 'ignores NaNs' // Actual != expected. ] diff --git a/tfjs-core/src/ops/pool.ts b/tfjs-core/src/ops/pool.ts index 7993e7fc38d..507a7981caf 100644 --- a/tfjs-core/src/ops/pool.ts +++ b/tfjs-core/src/ops/pool.ts @@ -81,8 +81,7 @@ function maxPoolImpl_( const convInfo = conv_util.computePool2DInfo( x4D.shape, filterSize, strides, dilations, pad, dimRoundingMode); if (convInfo.filterWidth === 1 && convInfo.filterHeight === 1 && - util.arraysEqual(convInfo.inShape, convInfo.outShape) && - convInfo.padInfo.type === 'VALID') { + util.arraysEqual(convInfo.inShape, convInfo.outShape)) { return $x.clone(); } @@ -192,14 +191,20 @@ function avgPoolImpl_( const convInfo = conv_util.computePool2DInfo( x4D.shape, filterSize, strides, dilations, pad, dimRoundingMode); + if (convInfo.filterWidth === 1 && convInfo.filterHeight === 1 && + util.arraysEqual(convInfo.inShape, convInfo.outShape)) { + return $x.clone(); + } const grad = (dy: Tensor4D) => { return { x: () => avgPoolBackprop(dy, x4D, filterSize, strides, dilations, pad) }; }; + let res = ENGINE.runKernelFunc( - backend => backend.avgPool(x4D, convInfo), {x: x4D}, grad); + backend => backend.avgPool(x4D, convInfo), {x: x4D}, grad, 'AvgPool', + convInfo); res = res.cast($x.dtype); if (reshapedTo4D) { return res.as3D(res.shape[1], res.shape[2], res.shape[3]) as T;