Skip to content

Commit

Permalink
[webgpu] Fix storage buffer exceed error in Concat operator (#6532)
Browse files: browse the repository at this point in the history
Fixes #6507
  • Loading branch information
haoyunfeix authored Jun 29, 2022
1 parent ccca854 commit ca9799d
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 19 deletions.
21 changes: 12 additions & 9 deletions tfjs-backend-webgl/src/kernels/Concat_impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -95,15 +95,18 @@ export function concatImpl(
return outInfo;
}

if (inputs.length > env().getNumber('WEBGL_MAX_TEXTURES_IN_SHADER')) {
const midIndex = Math.floor(inputs.length / 2);
const leftSide = concatImpl(inputs.slice(0, midIndex), axis, backend);
const rightSide = concatImpl(inputs.slice(midIndex), axis, backend);

const result = concatImpl([leftSide, rightSide], axis, backend);

backend.disposeIntermediateTensorInfo(leftSide);
backend.disposeIntermediateTensorInfo(rightSide);
const maxTexturesInShader = env().getNumber('WEBGL_MAX_TEXTURES_IN_SHADER');
if (inputs.length > maxTexturesInShader) {
const reducedInputs = [];
for (let i = 0; i < inputs.length; i += maxTexturesInShader) {
const subArray = inputs.slice(i, i + maxTexturesInShader);
reducedInputs.push(concatImpl(subArray, axis, backend));
}
const result = concatImpl(reducedInputs, axis, backend);

for (const i of reducedInputs) {
backend.disposeIntermediateTensorInfo(i);
}

return result;
}
Expand Down
20 changes: 19 additions & 1 deletion tfjs-backend-webgpu/src/kernels/Concat_impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@
import {backend_util, ConcatInputs, TensorInfo, util} from '@tensorflow/tfjs-core';

import {WebGPUBackend} from '../backend_webgpu';
import {ConcatProgram} from '../concat_webgpu';
import {concatImplCPU} from '../kernel_utils/shared';

import {complex} from './Complex';
import {ConcatProgram} from '../concat_webgpu';
import {imag} from './Imag';
import {real} from './Real';
import {reshape} from './Reshape';
Expand Down Expand Up @@ -94,6 +94,24 @@ export function concatImpl(
return outInfo;
}

// The compute stage can bind only a limited number of storage buffers. One of
// them is used for the output, so the maximum number of inputs is
// limits.maxStorageBuffersPerShaderStage - 1.
const maxInputNum = backend.device.limits.maxStorageBuffersPerShaderStage - 1;
if (inputs.length > maxInputNum) {
const reducedInputs = [];
for (let i = 0; i < inputs.length; i += maxInputNum) {
const subArray = inputs.slice(i, i + maxInputNum);
reducedInputs.push(concatImpl(subArray, axis, backend));
}
const result = concatImpl(reducedInputs, axis, backend);

for (const i of reducedInputs) {
backend.disposeData(i.dataId);
}

return result;
}

const {tensors2D, outShape} = computeTensors2D(inputs, axis, backend);
const shapes = (tensors2D).map(t => t.shape as [number, number]);
const program = new ConcatProgram(shapes);
Expand Down
9 changes: 0 additions & 9 deletions tfjs-backend-webgpu/src/setup_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,6 @@ const TEST_FILTERS: TestFilter[] = [
'accepts a tensor-like object', // tensor6d not yet implemented
]
},
{
startsWith: 'concat ',
excludes: [
'concat a large number of tensors', // The number of storage buffers
// exceeds the maximum per-stage
// limit.
]
},
{
startsWith: 'conv2d ',
excludes: [
Expand Down Expand Up @@ -291,7 +283,6 @@ const TEST_FILTERS: TestFilter[] = [
'avgPool3dBackprop ',
'bincount ',
'broadcastArgs ',
'concat3d ',
'conv2dTranspose ',
'conv2DBackpropFilter ',
'gradient with clones, input=2x2x1,d2=1,f=1,s=1,d=1,p=same', // Conv2DBackpropFilter
Expand Down

0 comments on commit ca9799d

Please sign in to comment.