Skip to content

Commit

Permalink
[webgpu] Get device supported limits info from WebGPU API (#5833)
Browse files Browse the repository at this point in the history
* [webgpu] Get device supported limits info by WebGPU API
  • Loading branch information
haoyunfeix authored Feb 23, 2022
1 parent aa2c7e1 commit 6a0a2e8
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 51 deletions.
30 changes: 30 additions & 0 deletions tfjs-backend-webgpu/src/backend_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,34 @@ export interface WebGPUTimingInfo extends TimingInfo {
const CPU_HANDOFF_SIZE_THRESHOLD =
env().getNumber('WEBGPU_CPU_HANDOFF_SIZE_THRESHOLD');

// Reshape dispatch, not to exceed device limits.
const reshapeDispatch = (device: GPUDevice,
program: webgpu_program.WebGPUProgram): [number, number, number] => {
const MAX_COMPUTE_PER_DIMENSION_DISPATCH_SIZE =
device.limits.maxComputeWorkgroupsPerDimension;
const layout = program['dispatchLayout'];
const dispatch = program['dispatch'];
if (dispatch.every((d) => d <= MAX_COMPUTE_PER_DIMENSION_DISPATCH_SIZE)) {
return dispatch;
}

util.assert(
dispatch[0] > MAX_COMPUTE_PER_DIMENSION_DISPATCH_SIZE &&
layout.y === undefined && layout.z === undefined,
() => 'Dispatch size exceeds WebGPU limits in Y or Z dimension.');

let dispatchAverage = Math.ceil(Math.sqrt(dispatch[0]));
if (dispatchAverage > MAX_COMPUTE_PER_DIMENSION_DISPATCH_SIZE) {
dispatchAverage = Math.ceil(Math.cbrt(dispatch[0]));
util.assert(
dispatchAverage <= MAX_COMPUTE_PER_DIMENSION_DISPATCH_SIZE,
() => 'Total dispatch size exceeds WebGPU maximum.');
return [dispatchAverage, dispatchAverage, dispatchAverage];
} else {
return [dispatchAverage, dispatchAverage, 1];
}
};

export class WebGPUBackend extends KernelBackend {
device: GPUDevice;
queue: GPUQueue;
Expand Down Expand Up @@ -697,6 +725,7 @@ export class WebGPUBackend extends KernelBackend {
}
this.uploadToGPU(output.dataId);
}
program.dispatch = reshapeDispatch(this.device, program);

// There are five kinds of uniforms: NAN, shapes, shape strides, program
// size, program defined uniforms.
Expand Down Expand Up @@ -811,6 +840,7 @@ export class WebGPUBackend extends KernelBackend {
runFromPixelsProgram(
program: FromPixelsProgram, output: GPUBuffer, layout: WebGPULayout,
externalResource: GPUExternalTexture|GPUTextureView, outputId: DataId) {
program.dispatch = reshapeDispatch(this.device, program);
const bindGroup = this.device.createBindGroup({
layout: layout.bindGroupLayout,
entries: [
Expand Down
26 changes: 0 additions & 26 deletions tfjs-backend-webgpu/src/constants.ts

This file was deleted.

11 changes: 9 additions & 2 deletions tfjs-backend-webgpu/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,18 @@ if (isWebGPUSupported()) {
};

const adapter = await navigator.gpu.requestAdapter(gpuDescriptor);
let deviceDescriptor: GPUDeviceDescriptor = {};
const adapterLimits = adapter.limits;
const deviceDescriptor: GPUDeviceDescriptor = {};
const supportTimeQuery = adapter.features.has('timestamp-query');
deviceDescriptor.requiredLimits = {
'maxComputeWorkgroupStorageSize':
adapterLimits.maxComputeWorkgroupStorageSize,
'maxComputeWorkgroupsPerDimension':
adapterLimits.maxComputeWorkgroupsPerDimension,
};

if (supportTimeQuery) {
deviceDescriptor = {requiredFeatures: ['timestamp-query' as const]};
deviceDescriptor.requiredFeatures = ['timestamp-query' as const];
} else {
console.warn(
`This device doesn't support timestamp-query extension. ` +
Expand Down
25 changes: 2 additions & 23 deletions tfjs-backend-webgpu/src/webgpu_util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,7 @@
* limitations under the License.
* =============================================================================
*/
import {DataType, util} from '@tensorflow/tfjs-core';

import {MAX_COMPUTE_PER_DIMENSION_DISPATCH_SIZE} from './constants';
import {DataType} from '@tensorflow/tfjs-core';

const arrayProduct = (arr: number[]) => {
let product = 1;
Expand Down Expand Up @@ -58,26 +56,7 @@ export function computeDispatch(
(workGroupSize[2] * elementsPerThread[2])) :
1
];

if (dispatchX <= MAX_COMPUTE_PER_DIMENSION_DISPATCH_SIZE &&
dispatchY <= MAX_COMPUTE_PER_DIMENSION_DISPATCH_SIZE &&
dispatchZ <= MAX_COMPUTE_PER_DIMENSION_DISPATCH_SIZE) {
return [dispatchX, dispatchY, dispatchZ];
}

util.assert(dispatchX > MAX_COMPUTE_PER_DIMENSION_DISPATCH_SIZE &&
layout.y === undefined && layout.z === undefined, () =>
'Dispatch size exceeds WebGPU limits in Y or Z dimension.');

let dispatchAverage = Math.ceil(Math.sqrt(dispatchX));
if (dispatchAverage > MAX_COMPUTE_PER_DIMENSION_DISPATCH_SIZE) {
dispatchAverage = Math.ceil(Math.cbrt(dispatchX));
util.assert(dispatchAverage <= MAX_COMPUTE_PER_DIMENSION_DISPATCH_SIZE,
() => 'Total dispatch size exceeds WebGPU maximum.');
return [dispatchAverage, dispatchAverage, dispatchAverage];
} else {
return [dispatchAverage, dispatchAverage, 1];
}
return [dispatchX, dispatchY, dispatchZ];
}

export function computeWorkGroupSizeForConv2d(
Expand Down

0 comments on commit 6a0a2e8

Please sign in to comment.