diff --git a/tfjs-backend-wasm/scripts/cpplint.js b/tfjs-backend-wasm/scripts/cpplint.js
index 1a2cb633880..4b5c040ea9e 100755
--- a/tfjs-backend-wasm/scripts/cpplint.js
+++ b/tfjs-backend-wasm/scripts/cpplint.js
@@ -20,13 +20,33 @@ const fs = require('fs');
 
 const CC_FILEPATH = 'src/cc';
 
-const result = shell.find('src/cc').filter(
+let python2Cmd;  // Command that runs Python 2; stays undefined if none found.
+
+const ignoreCode = true;  // NOTE(review): presumably keeps exec from aborting.
+const commandOpts = null;
+
+let pythonVersion = exec('python --version', commandOpts, ignoreCode);
+if(pythonVersion['stderr'].includes('Python 2')) {  // `python` is Python 2.
+  python2Cmd = 'python';
+} else {  // Otherwise probe for a separate `python2` executable.
+  pythonVersion = exec('python2 --version', commandOpts, ignoreCode);
+  if(pythonVersion.code === 0) {
+    python2Cmd = 'python2';
+  }
+}
+
+if(python2Cmd != null) {  // Lint only when a Python 2 command was found.
+  const result = shell.find('src/cc').filter(
     fileName => fileName.endsWith('.cc') || fileName.endsWith('.h'));
 
-console.log(`C++ linting files:`);
-console.log(result);
+  console.log(`C++ linting files:`);
+  console.log(result);
 
-const cwd = process.cwd() + '/' + CC_FILEPATH;
+  const cwd = process.cwd() + '/' + CC_FILEPATH;
+  const filenameArgument = result.join(' ');
 
-const filenameArgument = result.join(' ');
-exec(`python2 tools/cpplint.py --root ${cwd} ${filenameArgument}`);
+  exec(`${python2Cmd} tools/cpplint.py --root ${cwd} ${filenameArgument}`);
+} else {  // Linting is skipped with a warning; nothing is thrown here.
+  console.warn('No python2.x version found - please install python2. ' +
+      'cpplint.py only works correctly with python 2.x.');
+}
diff --git a/tfjs-backend-wasm/src/cc/BUILD b/tfjs-backend-wasm/src/cc/BUILD
index 63713f3792d..aa1fa7fcc07 100644
--- a/tfjs-backend-wasm/src/cc/BUILD
+++ b/tfjs-backend-wasm/src/cc/BUILD
@@ -68,6 +68,7 @@ tfjs_cc_library(
     ":Abs",
     ":Add",
     ":BatchMatMul",
+    ":CropAndResize",
     ":Conv2D",
     ":Div",
     ":Mul",
@@ -136,6 +137,15 @@ tfjs_cc_library(
   ],
 )
 
+tfjs_cc_library(
+  name = "CropAndResize",  # Built from kernels/CropAndResize.cc.
+  srcs = ["kernels/CropAndResize.cc"],
+  deps = [
+    ":backend",
+    ":util",
+  ],
+)
+
 tfjs_cc_library(
   name = "Conv2D",
   srcs = ["kernels/Conv2D.cc"],
diff --git a/tfjs-backend-wasm/src/cc/kernels/CropAndResize.cc b/tfjs-backend-wasm/src/cc/kernels/CropAndResize.cc
new file mode 100644
index 00000000000..de782429d0a
--- /dev/null
+++ b/tfjs-backend-wasm/src/cc/kernels/CropAndResize.cc
@@ -0,0 +1,257 @@
+/* Copyright 2019 Google Inc. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * ===========================================================================*/
+
+#ifdef __EMSCRIPTEN__
+#include <emscripten.h>
+#endif
+
+#include <math.h>
+#include <string.h>
+
+#include <vector>
+
+#include "src/cc/backend.h"
+#include "src/cc/util.h"
+
+// Must match enum in CropAndResize.ts
+enum InterpolationMethod {
+  BILINEAR = 0,
+  NEAREST = 1,
+};
+
+namespace {
+// Bilinearly resamples one output row of `crop_width` pixels from the image
+// row pair surrounding `y_ind`, writing `num_channels` values per pixel to
+// `out_buf_ptr`. Pixels whose source x falls outside [0, image_width - 1] are
+// filled with `extrapolation_value`. `y_ind` must already be inside the image
+// (the caller handles out-of-range rows); `box_ind` is the element offset of
+// the source image within `images_buf`.
+template <typename T>
+void interpolate_bilinear(T* out_buf_ptr, const T* images_buf,
+                          const std::vector<int>& images_strides,
+                          int crop_width, int image_width, int image_width_m1,
+                          int num_channels, float extrapolation_value,
+                          int box_ind, float y_ind, float width_scale, float x1,
+                          float x2) {
+  // Integer row indices keep the offset arithmetic exact for large images.
+  const int top_ind = static_cast<int>(floor(y_ind));
+  const int bottom_ind = static_cast<int>(ceil(y_ind));
+  const float y_lerp = y_ind - top_ind;
+  const int top_row = box_ind + top_ind * images_strides[1];
+  const int bottom_row = box_ind + bottom_ind * images_strides[1];
+
+  for (int x = 0; x < crop_width; ++x) {
+    const float x_ind = (crop_width > 1) ? x1 * image_width_m1 + x * width_scale
+                                         : 0.5 * (x1 + x2) * image_width_m1;
+
+    if (x_ind < 0 || x_ind > image_width - 1) {
+      for (int c = 0; c < num_channels; ++c) {
+        *out_buf_ptr = extrapolation_value;
+        out_buf_ptr++;
+      }
+      continue;
+    }
+
+    const int left_col = static_cast<int>(floor(x_ind)) * images_strides[2];
+    const int right_col = static_cast<int>(ceil(x_ind)) * images_strides[2];
+    const float x_lerp = x_ind - floor(x_ind);
+
+    for (int c = 0; c < num_channels; ++c) {
+      const float top_left = images_buf[top_row + left_col + c];
+      const float top_right = images_buf[top_row + right_col + c];
+      const float bottom_left = images_buf[bottom_row + left_col + c];
+      const float bottom_right = images_buf[bottom_row + right_col + c];
+
+      const float top = top_left + (top_right - top_left) * x_lerp;
+      const float bottom = bottom_left + (bottom_right - bottom_left) * x_lerp;
+
+      *out_buf_ptr = top + ((bottom - top) * y_lerp);
+      out_buf_ptr++;
+    }
+  }
+}
+
+// Nearest-neighbor counterpart of interpolate_bilinear: every in-range output
+// pixel copies the source pixel closest to (`y_ind`, x_ind); out-of-range
+// pixels are filled with `extrapolation_value`.
+template <typename T>
+void interpolate_nearest(T* out_buf_ptr, const T* images_buf,
+                         const std::vector<int>& images_strides, int crop_width,
+                         int image_width, int image_width_m1, int num_channels,
+                         float extrapolation_value, int box_ind, float y_ind,
+                         float width_scale, float x1, float x2) {
+  const int closest_row =
+      box_ind + static_cast<int>(round(y_ind)) * images_strides[1];
+
+  for (int x = 0; x < crop_width; ++x) {
+    const float x_ind = (crop_width > 1) ? x1 * image_width_m1 + x * width_scale
+                                         : 0.5 * (x1 + x2) * image_width_m1;
+
+    if (x_ind < 0 || x_ind > image_width - 1) {
+      for (int c = 0; c < num_channels; ++c) {
+        *out_buf_ptr = extrapolation_value;
+        out_buf_ptr++;
+      }
+      continue;
+    }
+
+    const int closest_col = static_cast<int>(round(x_ind)) * images_strides[2];
+    for (int c = 0; c < num_channels; ++c) {
+      *out_buf_ptr = images_buf[closest_row + closest_col + c];
+      out_buf_ptr++;
+    }
+  }
+}
+
+}  // namespace
+
+namespace tfjs {
+namespace wasm {
+// We use C-style API to interface with Javascript.
+extern "C" {
+
+// Extracts `num_boxes` crops from a batch of images and resizes every crop to
+// crop_height x crop_width, writing the result to the tensor `out_id`.
+//
+//   images_id, boxes_id, box_ind_id, out_id: tensor ids resolved through the
+//       backend. Images and boxes are read as float32, box indices as int32.
+//   images_shape_ptr: four ints, read as [batch, height, width, channels].
+//   boxes: four floats per box in the order [y1, x1, y2, x2]; coordinates are
+//       multiplied by (size - 1) to obtain pixel positions.
+//   box indices: for each box, the index of the source image in the batch.
+//   method: BILINEAR or NEAREST sampling.
+//   extrapolation_value: written for every sample that falls outside the
+//       source image.
+//
+// Boxes whose image index lies outside [0, batch) are skipped.
+#ifdef __EMSCRIPTEN__
+EMSCRIPTEN_KEEPALIVE
+#endif
+void CropAndResize(int images_id, int boxes_id, int box_ind_id, int num_boxes,
+                   int* images_shape_ptr, int crop_height, int crop_width,
+                   InterpolationMethod method, float extrapolation_value,
+                   int out_id) {
+  // Copy the four image dimensions out of the raw pointer.
+  const int images_shape_length = 4;
+  const std::vector<int>& images_shape = std::vector<int>(
+      images_shape_ptr, images_shape_ptr + images_shape_length);
+  const auto images_strides = util::compute_strides(images_shape);
+
+  // Output layout is [num_boxes, crop_height, crop_width, channels].
+  const std::vector<int>& output_shape = {num_boxes, crop_height, crop_width,
+                                          images_shape[3]};
+  const auto output_strides = util::compute_strides(output_shape);
+
+  auto& images_info = backend::get_tensor_info(images_id);
+  auto& boxes_info = backend::get_tensor_info(boxes_id);
+  auto& box_ind_info = backend::get_tensor_info(box_ind_id);
+  auto& out_info = backend::get_tensor_info_out(out_id);
+
+  const float* images_buf = images_info.f32();
+  const float* boxes_buf = boxes_info.f32();
+  const int* box_ind_buf = box_ind_info.i32();
+  float* out_buf = out_info.f32_write();
+
+  const int batch = images_shape[0];
+  const int image_height = images_shape[1];
+  const int image_width = images_shape[2];
+  const int num_channels = images_shape[3];
+
+  const int image_height_m1 = image_height - 1;
+  const int image_width_m1 = image_width - 1;
+
+  for (int b = 0; b < num_boxes; ++b) {
+    // Index boxes and box indices by `b` so that skipping a box can never
+    // leave the two inputs out of step with each other.
+    const float* box = boxes_buf + 4 * b;
+    const float y1 = box[0];
+    const float x1 = box[1];
+    const float y2 = box[2];
+    const float x2 = box[3];
+
+    // Boxes that reference an image outside the batch are skipped; their
+    // output region is left untouched.
+    const int image_index = box_ind_buf[b];
+    if (image_index < 0 || image_index >= batch) {
+      continue;
+    }
+
+    const int box_ind = image_index * images_strides[0];
+
+    // Source-pixel step between neighboring output samples. It is 0 when the
+    // crop has a single row/column; the box center is sampled in that case.
+    const float height_scale =
+        (crop_height > 1) ? (y2 - y1) * image_height_m1 / (crop_height - 1) : 0;
+    const float width_scale =
+        (crop_width > 1) ? (x2 - x1) * image_width_m1 / (crop_width - 1) : 0;
+
+    // Fast path: when output pixels map one-to-one onto consecutive source
+    // pixels (unit x step that starts exactly on a pixel) and the whole span
+    // lies inside the image, a row that also lands exactly on a source row is
+    // a plain copy for both interpolation methods.
+    const float x_start = x1 * image_width_m1;
+    const bool row_is_contiguous =
+        crop_width > 1 && width_scale == 1.0f && x_start >= 0 &&
+        x_start == floor(x_start) && x_start + crop_width <= image_width;
+
+    for (int y = 0; y < crop_height; ++y) {
+      const float y_ind = (crop_height > 1)
+                              ? y1 * image_height_m1 + y * height_scale
+                              : 0.5 * (y1 + y2) * image_height_m1;
+
+      // Start of output row `y` of box `b`.
+      float* out_buf_ptr =
+          out_buf + y * output_strides[1] + b * output_strides[0];
+
+      // A row that lies outside the image is filled with the extrapolation
+      // value.
+      if (y_ind < 0 || y_ind > image_height - 1) {
+        for (int x = 0; x < crop_width; ++x) {
+          for (int c = 0; c < num_channels; ++c) {
+            *out_buf_ptr = extrapolation_value;
+            out_buf_ptr++;
+          }
+        }
+        continue;
+      }
+
+      if (row_is_contiguous && y_ind == floor(y_ind)) {
+        // Copy from a row-local pointer: `images_buf` must keep pointing at
+        // the start of the images tensor for the remaining rows and boxes.
+        const float* src_row = images_buf + box_ind +
+                               static_cast<int>(y_ind) * images_strides[1] +
+                               static_cast<int>(x_start) * images_strides[2];
+        memcpy(out_buf_ptr, src_row, sizeof(float) * crop_width * num_channels);
+        continue;
+      }
+
+      // General case: sample the row with the requested method.
+      if (method == InterpolationMethod::BILINEAR) {
+        interpolate_bilinear(out_buf_ptr, images_buf, images_strides,
+                             crop_width, image_width, image_width_m1,
+                             num_channels, extrapolation_value, box_ind, y_ind,
+                             width_scale, x1, x2);
+      } else {
+        interpolate_nearest(out_buf_ptr, images_buf, images_strides, crop_width,
+                            image_width, image_width_m1, num_channels,
+                            extrapolation_value, box_ind, y_ind, width_scale,
+                            x1, x2);
+      }
+    }
+  }
+}
+
+}  // extern "C"
+}  // namespace wasm
+}  // namespace tfjs
diff --git a/tfjs-backend-wasm/src/kernels/CropAndResize.ts b/tfjs-backend-wasm/src/kernels/CropAndResize.ts
new file mode 100644
index 00000000000..b6cb223b3d2
--- /dev/null
+++ b/tfjs-backend-wasm/src/kernels/CropAndResize.ts
@@ -0,0 +1,96 @@
+/**
+ * @license
+ * Copyright 2019 Google Inc. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+import {NamedAttrMap, NamedTensorInfoMap, registerKernel, TensorInfo} from '@tensorflow/tfjs-core';
+
+import {BackendWasm} from '../backend_wasm';
+
+interface CropAndResizeInputs extends NamedTensorInfoMap {
+  images: TensorInfo;
+  boxes: TensorInfo;
+  boxInd: TensorInfo;
+}
+
+interface CropAndResizeAttrs extends NamedAttrMap {
+  method: keyof typeof InterpolationMethod;  // 'bilinear' | 'nearest'
+  extrapolationValue: number;
+  cropSize: [number, number];
+}
+
+// Must match enum in CropAndResize.cc
+enum InterpolationMethod {
+  bilinear = 0,
+  nearest = 1
+}
+
+let wasmCropAndResize: (
+    imagesId: number, boxesId: number, boxIndId: number, numBoxes: number,
+    imagesShape: Uint8Array, cropHeight: number, cropWidth: number,
+    method: number, extrapolationValue: number, outId: number) => void;
+
+/** Binds the `CropAndResize` wasm export to `wasmCropAndResize`. */
+function setup(backend: BackendWasm): void {
+  wasmCropAndResize = backend.wasm.cwrap('CropAndResize', null /*void*/, [
+    'number',  // imagesId
+    'number',  // boxesId
+    'number',  // boxIndId
+    'number',  // numBoxes
+    'array',   // images shape
+    'number',  // cropHeight
+    'number',  // cropWidth
+    'number',  // method
+    'number',  // extrapolation value
+    'number'   // out id
+  ]);
+}
+
+/** Allocates the output tensor and runs the CropAndResize wasm kernel. */
+function cropAndResize(args: {
+  backend: BackendWasm,
+  inputs: CropAndResizeInputs,
+  attrs: CropAndResizeAttrs
+}): TensorInfo {
+  const {backend, inputs, attrs} = args;
+  const {method, extrapolationValue, cropSize} = attrs;
+  const {images, boxes, boxInd} = inputs;
+
+  const numBoxes = boxes.shape[0];
+  const [cropHeight, cropWidth] = cropSize as [number, number];
+  const outShape = [numBoxes, cropHeight, cropWidth, images.shape[3]];
+
+  const imagesId = backend.dataIdMap.get(images.dataId).id;
+  const boxesId = backend.dataIdMap.get(boxes.dataId).id;
+  const boxIndId = backend.dataIdMap.get(boxInd.dataId).id;
+
+  const out = backend.makeOutput(outShape, images.dtype);
+  const outId = backend.dataIdMap.get(out.dataId).id;
+
+  // The image shape is passed as the raw bytes of an int32 array.
+  const imagesShapeBytes = new Uint8Array(new Int32Array(images.shape).buffer);
+
+  wasmCropAndResize(
+      imagesId, boxesId, boxIndId, numBoxes, imagesShapeBytes, cropHeight,
+      cropWidth, InterpolationMethod[method], extrapolationValue, outId);
+  return out;
+}
+
+registerKernel({
+  kernelName: 'CropAndResize',
+  backendName: 'wasm',
+  setupFunc: setup,
+  kernelFunc: cropAndResize
+});
diff --git a/tfjs-backend-wasm/src/kernels/all_kernels.ts b/tfjs-backend-wasm/src/kernels/all_kernels.ts
index 0125b92195b..085baf4af3e 100644
--- a/tfjs-backend-wasm/src/kernels/all_kernels.ts
+++ b/tfjs-backend-wasm/src/kernels/all_kernels.ts
@@ -21,6 +21,7 @@
 import './Abs';
 import './Add';
 import './BatchMatMul';
+import './CropAndResize';
 import './FusedBatchNorm';
 import './Cast';
 import './Concat';
diff --git a/tfjs-backend-wasm/src/setup_test.ts b/tfjs-backend-wasm/src/setup_test.ts
index cd48c5bda1a..986e0d20bba 100644
--- a/tfjs-backend-wasm/src/setup_test.ts
+++ b/tfjs-backend-wasm/src/setup_test.ts
@@ -36,6 +36,7 @@ const TEST_FILTERS: TestFilter[] = [
       'complex',  // Complex numbers not supported yet
     ]
   },
+  {include: 'cropAndResize'},
   {
     include: 'matmul ',
     excludes: [
@@ -118,6 +119,7 @@ const TEST_FILTERS: TestFilter[] = [
       'broadcasting same rank Tensors different shape',  // Broadcasting along
                                                          // inner dims not
                                                          // supported yet.
+      'divNoNan'  // divNoNan not yet implemented.
     ]
   },
   {
diff --git a/tfjs-core/src/ops/image_ops.ts b/tfjs-core/src/ops/image_ops.ts
index c3014b5db28..26b15f7bc9d 100644
--- a/tfjs-core/src/ops/image_ops.ts
+++ b/tfjs-core/src/ops/image_ops.ts
@@ -302,7 +302,9 @@ function cropAndResize_(
       backend.cropAndResize(
           $image, $boxes, $boxInd, cropSize, method, extrapolationValue);
 
-  const res = ENGINE.runKernelFunc(forward, {$image, $boxes});
+  const res = ENGINE.runKernelFunc(
+      forward, {images: $image, boxes: $boxes, boxInd: $boxInd}, null /* der */,
+      'CropAndResize', {method, extrapolationValue, cropSize});
   return res;
 }
 