Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[wasm] Add CropAndResize kernel. #2307

Merged
merged 57 commits into from
Nov 11, 2019
Merged
Show file tree
Hide file tree
Changes from 47 commits
Commits
Show all changes
57 commits
Select commit Hold shift + click to select a range
f383c67
initial
annxingyuan Nov 1, 2019
b6f9f29
compile
annxingyuan Nov 1, 2019
25e3598
pass more args
annxingyuan Nov 1, 2019
34e0bfb
pass everything in
annxingyuan Nov 1, 2019
acb2a90
pass
annxingyuan Nov 1, 2019
951184d
begin iter
annxingyuan Nov 1, 2019
f956c53
outline
annxingyuan Nov 1, 2019
6a27ea8
send shapes
annxingyuan Nov 1, 2019
2bf8c87
revive
annxingyuan Nov 4, 2019
26b7b17
debug
annxingyuan Nov 4, 2019
38dc00f
Merge branch 'master' into wasm_cropresize
annxingyuan Nov 4, 2019
07d8326
properly pass shape
annxingyuan Nov 4, 2019
9f8c241
types
annxingyuan Nov 4, 2019
3e2cb39
test
annxingyuan Nov 4, 2019
c0f1f86
fix
annxingyuan Nov 4, 2019
1144921
tests pass
annxingyuan Nov 4, 2019
000b66d
Merge branch 'master' into wasm_cropresize
annxingyuan Nov 5, 2019
9be6125
update
annxingyuan Nov 5, 2019
635260c
clean
annxingyuan Nov 5, 2019
4eae6e6
clean
annxingyuan Nov 5, 2019
dfbb107
point
annxingyuan Nov 5, 2019
77ba34c
offset
annxingyuan Nov 5, 2019
f8766de
mege
annxingyuan Nov 6, 2019
c503dec
memcpy
annxingyuan Nov 6, 2019
cf0ee20
Merge branch 'master' into wasm_cropresize
annxingyuan Nov 6, 2019
e062bcb
remove print
annxingyuan Nov 6, 2019
7b38666
simplify
annxingyuan Nov 6, 2019
d64a6af
clean
annxingyuan Nov 6, 2019
c06f1af
clean
annxingyuan Nov 6, 2019
0fe6ac5
clean
annxingyuan Nov 6, 2019
0526ff6
clean
annxingyuan Nov 6, 2019
45fa3e5
simplify
annxingyuan Nov 6, 2019
9bf394a
clean
annxingyuan Nov 7, 2019
df0c784
clean
annxingyuan Nov 7, 2019
1aff51e
clean
annxingyuan Nov 7, 2019
1145164
clean
annxingyuan Nov 7, 2019
eed18eb
clean
annxingyuan Nov 7, 2019
2283964
Merge branch 'master' into wasm_cropresize
annxingyuan Nov 8, 2019
f87e45c
clean
annxingyuan Nov 8, 2019
6bd59ca
clean
annxingyuan Nov 8, 2019
f3cc5f4
clean
annxingyuan Nov 8, 2019
7a2651a
size
annxingyuan Nov 8, 2019
f142be1
clean
annxingyuan Nov 8, 2019
6e3cb60
use const
annxingyuan Nov 8, 2019
fab1a2b
use const
annxingyuan Nov 8, 2019
72dcca2
comment
annxingyuan Nov 8, 2019
56995f7
remove excludes
annxingyuan Nov 8, 2019
89ae1e1
save
annxingyuan Nov 8, 2019
06a6fb6
lint
annxingyuan Nov 8, 2019
5b4fb56
python
annxingyuan Nov 8, 2019
43b9b10
Merge branch 'master' into wasm_cropresize
annxingyuan Nov 8, 2019
b8f7bef
save
annxingyuan Nov 8, 2019
39bc17f
Merge branch 'master' into wasm_cropresize
annxingyuan Nov 11, 2019
500c01d
update
annxingyuan Nov 11, 2019
532ce25
add setup test
annxingyuan Nov 11, 2019
fec7996
lint
annxingyuan Nov 11, 2019
19d8fa9
save
annxingyuan Nov 11, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions tfjs-backend-wasm/src/cc/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ tfjs_cc_library(
":Abs",
":Add",
":BatchMatMul",
":CropAndResize",
":Conv2D",
":Div",
":Mul",
Expand Down Expand Up @@ -136,6 +137,15 @@ tfjs_cc_library(
],
)

# Build target for the WASM CropAndResize kernel (kernels/CropAndResize.cc).
# The source includes src/cc/backend.h and src/cc/util.h, hence the two deps.
tfjs_cc_library(
name = "CropAndResize",
srcs = ["kernels/CropAndResize.cc"],
deps = [
":backend",
":util",
],
)

tfjs_cc_library(
name = "Conv2D",
srcs = ["kernels/Conv2D.cc"],
Expand Down
256 changes: 256 additions & 0 deletions tfjs-backend-wasm/src/cc/kernels/CropAndResize.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,256 @@
/* Copyright 2019 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* ===========================================================================*/

#ifdef __EMSCRIPTEN__
#include <emscripten.h>
#endif

#include <cmath>
#include <cstring>
#include <vector>

#include "src/cc/backend.h"
#include "src/cc/util.h"

// Must match enum in CropAndResize.ts
// The numeric values are what the JavaScript side passes for the `method`
// argument of CropAndResize(), so they must not be renumbered independently.
enum InterpolationMethod {
BILINEAR = 0,
NEAREST = 1,
};

namespace {
// Writes one output row (crop_width * num_channels values) of a crop using
// bilinear interpolation.
//
//   out_buf_ptr         - start of the output row; advanced by one element per
//                         value written.
//   images_buf          - base pointer of the source images buffer.
//   images_strides      - element strides of the source buffer; [1] is applied
//                         to the row index and [2] to the column index.
//   crop_width          - number of output pixels in the row.
//   image_width(_m1)    - source image width (and width - 1).
//   num_channels        - values per pixel.
//   extrapolation_value - written for output pixels whose source x coordinate
//                         falls outside [0, image_width - 1].
//   box_ind             - element offset added to every source index.
//   y_ind               - fractional source row for this output row; the
//                         caller must have verified it is within the image.
//   width_scale, x1, x2 - horizontal sampling step and box extents.
//
// All buffer indices are computed with integer arithmetic: mixing the float
// row/column values with the int offsets would evaluate the whole index in
// single precision and silently round offsets above 2^24.
template <typename T>
void interpolate_bilinear(T* out_buf_ptr, const T* images_buf,
                          const std::vector<int>& images_strides,
                          int crop_width, int image_width, int image_width_m1,
                          int num_channels, float extrapolation_value,
                          int box_ind, float y_ind, float width_scale, float x1,
                          float x2) {
  const float top_f = std::floor(y_ind);
  const float y_lerp = y_ind - top_f;

  // Element offsets of the two source rows that bracket y_ind.
  const int top_row = box_ind + static_cast<int>(top_f) * images_strides[1];
  const int bottom_row =
      box_ind + static_cast<int>(std::ceil(y_ind)) * images_strides[1];

  for (int x = 0; x < crop_width; ++x) {
    const float x_ind = (crop_width > 1) ? x1 * image_width_m1 + x * width_scale
                                         : 0.5 * (x1 + x2) * image_width_m1;

    if (x_ind < 0 || x_ind > image_width - 1) {
      // Sample point lies outside the image: emit the fill value.
      for (int c = 0; c < num_channels; ++c) {
        *out_buf_ptr++ = extrapolation_value;
      }
      continue;
    }

    const float left_f = std::floor(x_ind);
    const float x_lerp = x_ind - left_f;

    // Element offsets of the two source columns that bracket x_ind.
    const int left_col = static_cast<int>(left_f) * images_strides[2];
    const int right_col =
        static_cast<int>(std::ceil(x_ind)) * images_strides[2];

    for (int c = 0; c < num_channels; ++c) {
      const float top_left = images_buf[top_row + left_col + c];
      const float top_right = images_buf[top_row + right_col + c];
      const float bottom_left = images_buf[bottom_row + left_col + c];
      const float bottom_right = images_buf[bottom_row + right_col + c];

      // Interpolate horizontally on both rows, then vertically between them.
      const float top = top_left + (top_right - top_left) * x_lerp;
      const float bottom = bottom_left + (bottom_right - bottom_left) * x_lerp;

      *out_buf_ptr++ = top + ((bottom - top) * y_lerp);
    }
  }
}

// Writes one output row (crop_width * num_channels values) of a crop using
// nearest-neighbour sampling.
//
//   out_buf_ptr         - start of the output row; advanced by one element per
//                         value written.
//   images_buf          - base pointer of the source images buffer.
//   images_strides      - element strides of the source buffer; [1] is applied
//                         to the row index and [2] to the column index.
//   crop_width          - number of output pixels in the row.
//   image_width(_m1)    - source image width (and width - 1).
//   num_channels        - values per pixel.
//   extrapolation_value - written for output pixels whose source x coordinate
//                         falls outside [0, image_width - 1].
//   box_ind             - element offset added to every source index.
//   y_ind               - fractional source row for this output row; the
//                         caller must have verified it is within the image.
//   width_scale, x1, x2 - horizontal sampling step and box extents.
//
// The rounded row/column are converted to int before being combined with the
// strides and box offset; doing that arithmetic in float would round source
// offsets above 2^24 and read the wrong element.
template <typename T>
void interpolate_nearest(T* out_buf_ptr, const T* images_buf,
                         const std::vector<int>& images_strides,
                         int crop_width, int image_width, int image_width_m1,
                         int num_channels, float extrapolation_value,
                         int box_ind, float y_ind, float width_scale, float x1,
                         float x2) {
  // The source row does not depend on x, so resolve it once.
  const int row_offset =
      box_ind + static_cast<int>(std::round(y_ind)) * images_strides[1];

  for (int x = 0; x < crop_width; ++x) {
    const float x_ind = (crop_width > 1) ? x1 * image_width_m1 + x * width_scale
                                         : 0.5 * (x1 + x2) * image_width_m1;

    if (x_ind < 0 || x_ind > image_width - 1) {
      // Sample point lies outside the image: emit the fill value.
      for (int c = 0; c < num_channels; ++c) {
        *out_buf_ptr++ = extrapolation_value;
      }
      continue;
    }

    const int pixel_offset =
        row_offset + static_cast<int>(std::round(x_ind)) * images_strides[2];
    for (int c = 0; c < num_channels; ++c) {
      *out_buf_ptr++ = images_buf[pixel_offset + c];
    }
  }
}

} // namespace

namespace tfjs {
namespace wasm {
// We use C-style API to interface with Javascript.
extern "C" {

// Extracts `num_boxes` crops from a batch of images and resizes each one to
// crop_height x crop_width, writing float results into the output tensor.
//
//   images_id, boxes_id, box_ind_id, out_id
//       Tensor ids resolved through backend::get_tensor_info(). Images and
//       boxes are read as float32, box indices as int32, output is written as
//       float32.
//   num_boxes
//       Number of boxes; 4 floats (y1, x1, y2, x2, in that order) and one
//       int32 batch index are read per box.
//   images_shape_ptr
//       Pointer to 4 ints, read as [batch, height, width, channels].
//   crop_height, crop_width
//       Spatial size of every output crop.
//   method
//       BILINEAR or NEAREST (any value other than BILINEAR selects NEAREST).
//   extrapolation_value
//       Written wherever a sample point falls outside the source image.
//
// A box whose batch index is outside [0, batch) is skipped; the output
// elements of that box are left untouched.
#ifdef __EMSCRIPTEN__
EMSCRIPTEN_KEEPALIVE
#endif
void CropAndResize(int images_id, int boxes_id, int box_ind_id, int num_boxes,
                   int* images_shape_ptr, int crop_height, int crop_width,
                   InterpolationMethod method, float extrapolation_value,
                   int out_id) {
  const int images_shape_length = 4;
  const std::vector<int> images_shape(images_shape_ptr,
                                      images_shape_ptr + images_shape_length);
  const auto images_strides = util::compute_strides(images_shape);

  const std::vector<int> output_shape = {num_boxes, crop_height, crop_width,
                                         images_shape[3]};
  const auto output_strides = util::compute_strides(output_shape);

  auto& images_info = backend::get_tensor_info(images_id);
  auto& boxes_info = backend::get_tensor_info(boxes_id);
  auto& box_ind_info = backend::get_tensor_info(box_ind_id);
  auto& out_info = backend::get_tensor_info_out(out_id);

  // Base pointers. These are never advanced; every access below is computed
  // as base + offset so that one box/row cannot disturb the next.
  const float* images_buf = images_info.f32();
  const float* boxes_buf = boxes_info.f32();
  const int* box_ind_buf = box_ind_info.i32();
  float* out_buf = out_info.f32_write();

  const int batch = images_shape[0];
  const int image_height = images_shape[1];
  const int image_width = images_shape[2];
  const int num_channels = images_shape[3];

  const int image_height_m1 = image_height - 1;
  const int image_width_m1 = image_width - 1;

  for (int b = 0; b < num_boxes; ++b) {
    const float* box = boxes_buf + 4 * b;
    const float y1 = box[0];
    const float x1 = box[1];
    const float y2 = box[2];
    const float x2 = box[3];

    // Indexing by `b` (instead of bumping a cursor at the end of the loop
    // body) keeps boxes and indices in lock-step even when a box is skipped.
    const int batch_ind = box_ind_buf[b];
    if (batch_ind < 0 || batch_ind >= batch) {
      continue;
    }

    const int box_ind = batch_ind * images_strides[0];

    const float height_scale =
        (crop_height > 1) ? (y2 - y1) * image_height_m1 / (crop_height - 1) : 0;
    const float width_scale =
        (crop_width > 1) ? (x2 - x1) * image_width_m1 / (crop_width - 1) : 0;

    // Fast path precondition: horizontally the crop maps 1:1 onto whole source
    // pixels (unit step, integral start column) and the whole row lies inside
    // the image. Then every x sample hits exactly one source pixel and an
    // output row is a contiguous run of the source row.
    const float x_start = x1 * image_width_m1;
    const bool row_is_contiguous =
        crop_width > 1 && width_scale == 1.0f && x_start >= 0 &&
        x_start == std::floor(x_start) &&
        x_start + (crop_width - 1) <= image_width_m1;

    for (int y = 0; y < crop_height; ++y) {
      const float y_ind = (crop_height > 1)
                              ? y1 * image_height_m1 + y * height_scale
                              : 0.5 * (y1 + y2) * image_height_m1;

      float* out_buf_ptr =
          out_buf + y * output_strides[1] + b * output_strides[0];

      if (y_ind < 0 || y_ind > image_height - 1) {
        // The whole output row samples outside the image.
        for (int x = 0; x < crop_width; ++x) {
          for (int c = 0; c < num_channels; ++c) {
            *out_buf_ptr = extrapolation_value;
            out_buf_ptr++;
          }
        }
        continue;
      }

      if (row_is_contiguous) {
        const bool bilinear = method == InterpolationMethod::BILINEAR;
        // NEAREST always snaps to a single source row. BILINEAR only does so
        // when y_ind itself is integral (vertical lerp weight of zero).
        const float src_row = bilinear ? std::floor(y_ind) : std::round(y_ind);
        if (!bilinear || src_row == y_ind) {
          const float* src = images_buf + box_ind +
                             static_cast<int>(src_row) * images_strides[1] +
                             static_cast<int>(x_start) * images_strides[2];
          std::memcpy(out_buf_ptr, src,
                      sizeof(float) * crop_width * num_channels);
          continue;
        }
      }

      if (method == InterpolationMethod::BILINEAR) {
        interpolate_bilinear(out_buf_ptr, images_buf, images_strides,
                             crop_width, image_width, image_width_m1,
                             num_channels, extrapolation_value, box_ind, y_ind,
                             width_scale, x1, x2);
      } else {
        interpolate_nearest(out_buf_ptr, images_buf, images_strides, crop_width,
                            image_width, image_width_m1, num_channels,
                            extrapolation_value, box_ind, y_ind, width_scale,
                            x1, x2);
      }
    }
  }
}

} // extern "C"
} // namespace wasm
} // namespace tfjs
96 changes: 96 additions & 0 deletions tfjs-backend-wasm/src/kernels/CropAndResize.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
/**
* @license
* Copyright 2019 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {NamedAttrMap, NamedTensorInfoMap, registerKernel, TensorInfo} from '@tensorflow/tfjs-core';

import {BackendWasm} from '../backend_wasm';

/**
 * Tensor inputs of the CropAndResize kernel: the source images, the boxes to
 * crop, and one entry per box selecting which image of the batch it refers to.
 */
interface CropAndResizeInputs extends NamedTensorInfoMap {
images: TensorInfo;
boxes: TensorInfo;
boxInd: TensorInfo;
}

/** Non-tensor attributes of the CropAndResize kernel. */
interface CropAndResizeAttrs extends NamedAttrMap {
  // Name of an InterpolationMethod member. `keyof typeof` is required here:
  // `keyof InterpolationMethod` would yield the property names of the enum's
  // numeric value type rather than 'bilinear' | 'nearest'.
  method: keyof typeof InterpolationMethod;
  // Value written where a sample point falls outside the source image.
  extrapolationValue: number;
  // [cropHeight, cropWidth] of every output crop.
  cropSize: [number, number];
}

// Must match enum in CropAndResize.cc
// The numeric value is what gets passed as the `method` argument of the wasm
// CropAndResize function, so member values must stay in sync with the C++ enum.
enum InterpolationMethod {
bilinear = 0,
nearest = 1
}

// Handle to the wasm-exported CropAndResize function. It is assigned in
// setup() and must not be called before that. `imagesShape` carries the bytes
// of the images shape; `method` is the numeric InterpolationMethod value.
let wasmCropAndResize: (
imagesId: number, boxesId: number, boxIndId: number, numBoxes: number,
imagesShape: Uint8Array, cropHeight: number, cropWidth: number,
method: number, extrapolationValue: number, outId: number) => void;

/**
 * Binds `wasmCropAndResize` to the 'CropAndResize' function exported by the
 * wasm module. The argument-type list must stay in the same order as the
 * parameters of the C++ CropAndResize function.
 */
function setup(backend: BackendWasm): void {
wasmCropAndResize = backend.wasm.cwrap('CropAndResize', null /*void*/, [
'number', // imagesId
'number', // boxesId
'number', // boxIndId
'number', // numBoxes
'array', // images shape
'number', // cropHeight
'number', // cropWidth
'number', // method
'number', // extrapolation value
'number' // out id
]);
}

/**
 * Kernel function for CropAndResize on the wasm backend.
 *
 * Allocates an output of shape [numBoxes, cropHeight, cropWidth, channels]
 * with the dtype of `images`, then hands the wasm ids of all tensors, the
 * images shape (as raw bytes of an Int32Array) and the scalar attributes to
 * the wasm implementation.
 *
 * NOTE(review): the C++ kernel reads `images` through f32() and writes the
 * output through f32_write(); presumably callers only pass float32 images -
 * confirm.
 */
function cropAndResize(args: {
  backend: BackendWasm,
  inputs: CropAndResizeInputs,
  attrs: CropAndResizeAttrs
}): TensorInfo {
  const {backend} = args;
  const {images, boxes, boxInd} = args.inputs;
  const {method, extrapolationValue} = args.attrs;
  const [cropHeight, cropWidth] = args.attrs.cropSize as [number, number];

  const numBoxes = boxes.shape[0];
  const channels = images.shape[3];

  // Resolves a tensor to the numeric id the wasm side knows it by.
  const wasmIdOf = (tensor: TensorInfo): number =>
      backend.dataIdMap.get(tensor.dataId).id;

  const imagesId = wasmIdOf(images);
  const boxesId = wasmIdOf(boxes);
  const boxIndId = wasmIdOf(boxInd);

  const out = backend.makeOutput(
      [numBoxes, cropHeight, cropWidth, channels], images.dtype);
  const outId = wasmIdOf(out);

  // cwrap's 'array' argument type expects a byte view of the data.
  const shapeAsInt32 = new Int32Array(images.shape);
  const imagesShapeBytes = new Uint8Array(shapeAsInt32.buffer);

  // Translate the method name into the numeric enum value shared with C++.
  const methodKey = method as {} as keyof typeof InterpolationMethod;
  const methodCode = InterpolationMethod[methodKey];

  wasmCropAndResize(
      imagesId, boxesId, boxIndId, numBoxes, imagesShapeBytes, cropHeight,
      cropWidth, methodCode, extrapolationValue, outId);

  return out;
}

// Registers the kernel for the 'wasm' backend: `setup` binds the wasm export,
// `cropAndResize` is the function invoked to run the kernel.
registerKernel({
kernelName: 'CropAndResize',
backendName: 'wasm',
setupFunc: setup,
kernelFunc: cropAndResize
});
1 change: 1 addition & 0 deletions tfjs-backend-wasm/src/kernels/all_kernels.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import './Abs';
import './Add';
import './BatchMatMul';
import './CropAndResize';
import './FusedBatchNorm';
import './Cast';
import './Concat';
Expand Down
1 change: 1 addition & 0 deletions tfjs-backend-wasm/src/setup_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ const TEST_FILTERS: TestFilter[] = [
'complex', // Complex numbers not supported yet
]
},
{include: 'cropAndResize'},
{
include: 'matmul ',
excludes: [
Expand Down
Loading