Skip to content

Commit

Permalink
Add MaxPool3DGrad
Browse files Browse the repository at this point in the history
  • Loading branch information
chunnienc committed Jan 23, 2023
1 parent f75c8bf commit 3f25c9d
Show file tree
Hide file tree
Showing 9 changed files with 264 additions and 41 deletions.
10 changes: 10 additions & 0 deletions tfjs-backend-wasm/src/cc/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,7 @@ tfjs_cc_library(
":Max",
":MaxPool",
":MaxPool3D",
":MaxPool3DGrad",
":Maximum",
":Min",
":Minimum",
Expand Down Expand Up @@ -985,6 +986,15 @@ tfjs_cc_library(
],
)

# Library target for the MaxPool3DGrad kernel (kernels/MaxPool3DGrad.cc).
tfjs_cc_library(
name = "MaxPool3DGrad",
srcs = ["kernels/MaxPool3DGrad.cc"],
deps = [
":backend",
":pool3d_impl",
],
)

tfjs_cc_library(
name = "Min",
srcs = ["kernels/Min.cc"],
Expand Down
2 changes: 1 addition & 1 deletion tfjs-backend-wasm/src/cc/kernels/AvgPool3D.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ void AvgPool3D(int x_id, int out_id, int batch_size, int channel_size,
return {0.0, 0};
},
/*filter_apply=*/
[](std::pair<float, int>& data, const float& val) {
[](std::pair<float, int>& data, int, const float& val) {
data.first += val;
++data.second;
},
Expand Down
54 changes: 26 additions & 28 deletions tfjs-backend-wasm/src/cc/kernels/AvgPool3DGrad.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,38 +48,36 @@ void AvgPool3DGrad(int dy_id, int dx_id, int batch_size, int channel_size,
const TensorInfo& dy_info = backend::get_tensor_info(dy_id);
TensorInfo& dx_info = backend::get_tensor_info_out(dx_id);

pad_front = effective_filter_depth - 1 - pad_front;
pad_top = effective_filter_height - 1 - pad_top;
pad_left = effective_filter_width - 1 - pad_left;
NDHWCPool3DInfo pool3d_info{
.batch_size = batch_size,
.channel_size = channel_size,
.in_depth = in_depth,
.in_height = in_height,
.in_width = in_width,
.out_depth = out_depth,
.out_height = out_height,
.out_width = out_width,
.stride_depth = stride_depth,
.stride_height = stride_height,
.stride_width = stride_width,
.dilation_depth = dilation_depth,
.dilation_height = dilation_height,
.dilation_width = dilation_width,
.effective_filter_depth = effective_filter_depth,
.effective_filter_height = effective_filter_height,
.effective_filter_width = effective_filter_width,
.pad_front = pad_front,
.pad_top = pad_top,
.pad_left = pad_left,
};
NDHWCPool3DGradImpl(
dy_info.f32(), dx_info.f32_write(), pool3d_info,
dy_info.f32(), dx_info.f32_write(),
NDHWCPool3DInfo{
.batch_size = batch_size,
.channel_size = channel_size,
.in_depth = in_depth,
.in_height = in_height,
.in_width = in_width,
.out_depth = out_depth,
.out_height = out_height,
.out_width = out_width,
.stride_depth = stride_depth,
.stride_height = stride_height,
.stride_width = stride_width,
.dilation_depth = dilation_depth,
.dilation_height = dilation_height,
.dilation_width = dilation_width,
.effective_filter_depth = effective_filter_depth,
.effective_filter_height = effective_filter_height,
.effective_filter_width = effective_filter_width,
.pad_front = effective_filter_depth - 1 - pad_front,
.pad_top = effective_filter_height - 1 - pad_top,
.pad_left = effective_filter_width - 1 - pad_left,
},
/*pixel_mask=*/
[avg_multiplier = 1.0f / (static_cast<float>(filter_depth) *
static_cast<float>(filter_height) *
static_cast<float>(filter_width))](
int, int, int, int, int, int) { return avg_multiplier; });
static_cast<float>(filter_width))](int, int) {
return avg_multiplier;
});
}

} // extern "C"
Expand Down
2 changes: 1 addition & 1 deletion tfjs-backend-wasm/src/cc/kernels/MaxPool3D.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ void MaxPool3D(int x_id, int out_id, int batch_size, int channel_size,
/*filter_init=*/
[]() -> float { return std::numeric_limits<float>::min(); },
/*filter_apply=*/
[](float& data, const float& val) { data = std::max(data, val); },
[](float& data, int, const float& val) { data = std::max(data, val); },
/*filter_aggregate=*/
[](const float& data) { return data; });
}
Expand Down
105 changes: 105 additions & 0 deletions tfjs-backend-wasm/src/cc/kernels/MaxPool3DGrad.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
/**
* @license
* Copyright 2023 Google LLC.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

#ifdef __EMSCRIPTEN__
#include <emscripten.h>
#endif

#include <algorithm>
#include <limits>
#include <utility>
#include <vector>

#include "tfjs-backend-wasm/src/cc/backend.h"
#include "tfjs-backend-wasm/src/cc/pool3d_impl.h"

namespace tfjs::wasm {

// We use C-style API to interface with Javascript.
extern "C" {

#ifdef __EMSCRIPTEN__
EMSCRIPTEN_KEEPALIVE
#endif

// Computes the gradient of a 3D max pooling op with respect to its input.
//
// The kernel runs in two passes:
//   1. A pooling pass over `x` records, for every output element, the flat
//      offset into `x` of the largest value inside its pooling window. Ties
//      resolve to the element visited last, because the comparison is `>=`.
//   2. A gradient pass writes every `dx` element as the sum of the `dy`
//      elements whose recorded maximum position equals that `dx` offset.
//
// REQUIRES:
// - Tensor `x`, `dx` and `dy` must have dtype float32 (checked in tfjs-core)
// - Tensor `x`, `dx` and `dy` must have data format 'NDHWC' (checked in
//   tfjs-core)
//
// Parameters:
// - `x_id`, `dy_id`: ids of the pooling input and of the incoming gradient.
// - `dx_id`: id of the output tensor receiving the gradient w.r.t. `x`.
// - The remaining integers are copied verbatim into `NDHWCPool3DInfo`.
// - `filter_depth` is not used by the computation; it is kept only so that
//   the exported signature stays unchanged.
void MaxPool3DGrad(int x_id, int dy_id, int dx_id, int batch_size,
                   int channel_size, int in_depth, int in_height, int in_width,
                   int out_depth, int out_height, int out_width,
                   int stride_depth, int stride_height, int stride_width,
                   int dilation_depth, int dilation_height, int dilation_width,
                   int effective_filter_depth, int effective_filter_height,
                   int effective_filter_width, int pad_front, int pad_top,
                   int pad_left, int filter_depth) {
  const TensorInfo& x_info = backend::get_tensor_info(x_id);
  const TensorInfo& dy_info = backend::get_tensor_info(dy_id);
  TensorInfo& dx_info = backend::get_tensor_info_out(dx_id);
  NDHWCPool3DInfo pool3d_info{
      .batch_size = batch_size,
      .channel_size = channel_size,
      .in_depth = in_depth,
      .in_height = in_height,
      .in_width = in_width,
      .out_depth = out_depth,
      .out_height = out_height,
      .out_width = out_width,
      .stride_depth = stride_depth,
      .stride_height = stride_height,
      .stride_width = stride_width,
      .dilation_depth = dilation_depth,
      .dilation_height = dilation_height,
      .dilation_width = dilation_width,
      .effective_filter_depth = effective_filter_depth,
      .effective_filter_height = effective_filter_height,
      .effective_filter_width = effective_filter_width,
      .pad_front = pad_front,
      .pad_top = pad_top,
      .pad_left = pad_left,
  };

  // Pass 1: for each output element, remember which input offset held the
  // maximum. The vector owns the scratch buffer, so it is released on every
  // exit path without a manual delete[].
  std::vector<int> max_positions(pool3d_info.out_size());
  NDHWCPool3DImpl</*IN=*/float, /*OUT=*/int>(
      x_info.f32(), max_positions.data(), pool3d_info,
      /*filter_init=*/
      []() -> std::pair<float, int> {
        // Seed with -infinity rather than numeric_limits<float>::min():
        // min() is the smallest *positive* normal float, so a window holding
        // only negative values would never update the running maximum.
        // The position starts at -1, which is never a valid offset, so a
        // window without any comparable value (e.g. all NaN) contributes no
        // gradient instead of pointing at offset 0.
        return {-std::numeric_limits<float>::infinity(), -1};
      },
      /*filter_apply=*/
      [](std::pair<float, int>& data, int x_offset, const float& x_val) {
        if (x_val >= data.first) {
          data = {x_val, x_offset};
        }
      },
      /*filter_aggregate=*/
      [](const std::pair<float, int>& data) { return data.second; });

  // Pass 2: the gradient implementation is driven with mirrored padding
  // (effective filter size - 1 - original padding).
  pool3d_info.pad_front = effective_filter_depth - 1 - pad_front;
  pool3d_info.pad_top = effective_filter_height - 1 - pad_top;
  pool3d_info.pad_left = effective_filter_width - 1 - pad_left;
  NDHWCPool3DGradImpl(
      dy_info.f32(), dx_info.f32_write(), pool3d_info,
      /*pixel_mask=*/
      [&max_positions](int dy_offset, int dx_offset) {
        // 1.0f when this dx element was the argmax of the dy element's
        // window, 0.0f otherwise.
        return static_cast<float>(dx_offset == max_positions[dy_offset]);
      });
}

} // extern "C"
} // namespace tfjs::wasm
15 changes: 8 additions & 7 deletions tfjs-backend-wasm/src/cc/pool3d_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ inline void NDHWCPool3DImpl(const IN* x_buf, OUT* out_buf,
x_col += info.dilation_width) {
int x_offset =
info.in_offset(batch, x_depth, x_row, x_col, channel);
filter_apply(filter_data, x_buf[x_offset]);
filter_apply(filter_data, x_offset, x_buf[x_offset]);
}
}
}
Expand Down Expand Up @@ -125,6 +125,8 @@ inline void NDHWCPool3DGradImpl(const DY* dy_buf, DX* dx_buf,
int dy_row_corner = dx_row - info.pad_top;
int dy_col_corner = dx_col - info.pad_left;

int dx_offset =
info.in_offset(batch, dx_depth, dx_row, dx_col, channel);
DX dot_prod = 0;
for (int w_depth = 0; w_depth < info.effective_filter_depth;
w_depth += info.dilation_depth) {
Expand All @@ -148,15 +150,14 @@ inline void NDHWCPool3DGradImpl(const DY* dy_buf, DX* dx_buf,
continue;
}

DY pixel = dy_buf[info.out_offset(batch, dy_depth, dy_row,
dy_col, channel)];
dot_prod += pixel * pixel_mask(dy_depth, dy_row, dy_col,
w_depth, w_row, w_col);
int dy_offset =
info.out_offset(batch, dy_depth, dy_row, dy_col, channel);
DY pixel = dy_buf[dy_offset];
dot_prod += pixel * pixel_mask(dy_offset, dx_offset);
}
}
}
dx_buf[info.in_offset(batch, dx_depth, dx_row, dx_col, channel)] =
dot_prod;
dx_buf[dx_offset] = dot_prod;
}
}
}
Expand Down
108 changes: 108 additions & 0 deletions tfjs-backend-wasm/src/kernels/MaxPool3DGrad.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
/**
* @license
* Copyright 2023 Google LLC.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {backend_util, KernelConfig, KernelFunc, MaxPool3DGrad, MaxPool3DGradAttrs, MaxPool3DGradInputs, TensorInfo} from '@tensorflow/tfjs-core';

import {BackendWasm} from '../backend_wasm';

// Handle to the wasm `MaxPool3DGrad` function. It is unset until `setup()`
// assigns the wrapper returned by `cwrap`. All arguments are plain numbers:
// three tensor ids followed by the pooling geometry.
let wasmMaxPool3DGrad: (
xId: number, dyId: number, dxId: number, batchSize: number,
channelSize: number, inDepth: number, inHeight: number, inWidth: number,
outDepth: number, outHeight: number, outWidth: number, strideDepth: number,
strideHeight: number, strideWidth: number, dilationDepth: number,
dilationHeight: number, dilationWidth: number, effectiveFilterDepth: number,
effectiveFilterHeight: number, effectiveFilterWidth: number,
padFront: number, padTop: number, padLeft: number) => void;

/**
 * Kernel setup hook: wraps the wasm function named 'MaxPool3DGrad' with
 * `cwrap` and stores the wrapper in `wasmMaxPool3DGrad`.
 *
 * The wrapper is declared with a `null` return type and 23 numeric
 * arguments; the trailing comment on each entry names the argument at that
 * position.
 *
 * @param backend The wasm backend whose module exposes `cwrap`.
 */
function setup(backend: BackendWasm) {
wasmMaxPool3DGrad = backend.wasm.cwrap('MaxPool3DGrad', null, [
'number', // xId
'number', // dyId
'number', // dxId
'number', // batchSize
'number', // channelSize
'number', // inDepth
'number', // inHeight
'number', // inWidth
'number', // outDepth
'number', // outHeight
'number', // outWidth
'number', // strideDepth
'number', // strideHeight
'number', // strideWidth
'number', // dilationDepth
'number', // dilationHeight
'number', // dilationWidth
'number', // effectiveFilterDepth
'number', // effectiveFilterHeight
'number', // effectiveFilterWidth
'number', // padFront
'number', // padTop
'number', // padLeft
]);
}

/**
 * MaxPool3DGrad kernel for the wasm backend.
 *
 * Derives the pooling geometry from the shape of `input` and the op
 * attributes (dilations are fixed to 1), allocates an output with the same
 * shape and dtype as `input`, and hands the tensor ids plus geometry to the
 * wasm implementation.
 *
 * @param args.inputs `dy` (incoming gradient) and `input` (pooling input).
 * @param args.attrs `filterSize`, `strides`, `pad` and `dimRoundingMode`.
 * @param args.backend The wasm backend that owns the tensor data.
 * @return The tensor info of the newly allocated gradient `dx`.
 */
export function maxPool3DGrad(args: {
  inputs: MaxPool3DGradInputs,
  attrs: MaxPool3DGradAttrs,
  backend: BackendWasm,
}): TensorInfo {
  const wasmBackend = args.backend;
  const {dy, input} = args.inputs;
  const {filterSize, strides, pad, dimRoundingMode} = args.attrs;

  const poolInfo = backend_util.computePool3DInfo(
      input.shape as [number, number, number, number, number], filterSize,
      strides, /*dilations=*/1, pad, dimRoundingMode);
  const dx = wasmBackend.makeOutput(input.shape, input.dtype);

  // Resolve the wasm-side ids only after the output has been allocated.
  const xId = wasmBackend.dataIdMap.get(input.dataId).id;
  const dyId = wasmBackend.dataIdMap.get(dy.dataId).id;
  const dxId = wasmBackend.dataIdMap.get(dx.dataId).id;
  const padInfo = poolInfo.padInfo;

  wasmMaxPool3DGrad(
      xId,
      dyId,
      dxId,
      poolInfo.batchSize,
      // Since Pool3D ops (AvgPool3D and MaxPool3D) support 3D filter only, in
      // channels should always equal to out channels.
      /*channelSize=*/poolInfo.inChannels,
      poolInfo.inDepth,
      poolInfo.inHeight,
      poolInfo.inWidth,
      poolInfo.outDepth,
      poolInfo.outHeight,
      poolInfo.outWidth,
      poolInfo.strideDepth,
      poolInfo.strideHeight,
      poolInfo.strideWidth,
      poolInfo.dilationDepth,
      poolInfo.dilationHeight,
      poolInfo.dilationWidth,
      poolInfo.effectiveFilterDepth,
      poolInfo.effectiveFilterHeight,
      poolInfo.effectiveFilterWidth,
      padInfo.front,
      padInfo.top,
      padInfo.left,
  );
  return dx;
}

// Kernel config for MaxPool3DGrad on the 'wasm' backend: `setup` is the
// setup hook and `maxPool3DGrad` is the kernel function.
export const maxPool3DGradConfig: KernelConfig = {
kernelName: MaxPool3DGrad,
backendName: 'wasm',
setupFunc: setup,
kernelFunc: maxPool3DGrad as unknown as KernelFunc
};
2 changes: 2 additions & 0 deletions tfjs-backend-wasm/src/register_all_kernels.ts
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ import {maxConfig} from './kernels/Max';
import {maximumConfig} from './kernels/Maximum';
import {maxPoolConfig} from './kernels/MaxPool';
import {maxPool3DConfig} from './kernels/MaxPool3D';
import {maxPool3DGradConfig} from './kernels/MaxPool3DGrad';
import {meanConfig} from './kernels/Mean';
import {minConfig} from './kernels/Min';
import {minimumConfig} from './kernels/Minimum';
Expand Down Expand Up @@ -204,6 +205,7 @@ const kernelConfigs: KernelConfig[] = [
maximumConfig,
maxPoolConfig,
maxPool3DConfig,
maxPool3DGradConfig,
meanConfig,
minConfig,
minimumConfig,
Expand Down
7 changes: 3 additions & 4 deletions tfjs-backend-wasm/src/setup_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,9 @@ const TEST_FILTERS: TestFilter[] = [
{
include: 'maxPool',
excludes: [
'maxPoolBackprop', // Not yet implemented.
'maxPool3dBackprop', // Not yet implemented.
'ignores NaNs', // Actual != expected.
'maxPoolWithArgmax' // Not yet implemented.
'maxPoolBackprop', // Not yet implemented.
'ignores NaNs', // Actual != expected.
'maxPoolWithArgmax' // Not yet implemented.
]
},
{include: 'cropAndResize'},
Expand Down

0 comments on commit 3f25c9d

Please sign in to comment.