Skip to content
This repository has been archived by the owner on Aug 15, 2019. It is now read-only.

Commit

Permalink
ensure webgl backend read downloads converted typedarray (#1382)
Browse files Browse the repository at this point in the history
Fixes a bug where async downloading of a tensor that was uploaded to the GPU would always return a Float32Array.

BUG
  • Loading branch information
tafsiri authored Nov 7, 2018
1 parent 10ebb98 commit 0e01701
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 11 deletions.
19 changes: 9 additions & 10 deletions src/kernels/backend_webgl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -277,8 +277,7 @@ export class MathBackendWebGL implements KernelBackend {
const texData = this.texData.get(dataId);
const {values, dtype, complexTensors} = texData;
if (values != null) {
this.cacheOnCPU(dataId);
return values;
return this.convertAndCacheOnCPU(dataId);
}
const shouldTimeProgram = this.activeTimers != null;
let start: number;
Expand All @@ -298,8 +297,7 @@ export class MathBackendWebGL implements KernelBackend {
if (shouldTimeProgram) {
this.downloadWaitMs += performance.now() - start;
}
this.cacheOnCPU(dataId, result);
return texData.values;
return this.convertAndCacheOnCPU(dataId, result);
}

async read(dataId: DataId): Promise<TypedArray> {
Expand All @@ -310,8 +308,7 @@ export class MathBackendWebGL implements KernelBackend {
const texData = this.texData.get(dataId);
const {texture, values, texShape} = texData;
if (values != null) {
this.cacheOnCPU(dataId);
return values;
return this.convertAndCacheOnCPU(dataId);
}

this.pendingRead.set(dataId, []);
Expand All @@ -338,18 +335,18 @@ export class MathBackendWebGL implements KernelBackend {
vals = this.gpgpu.downloadFloat32MatrixFromBuffer(
bufferOrTexture, texShape[0], texShape[1]);
}
this.cacheOnCPU(dataId, vals);
const dTypeVals = this.convertAndCacheOnCPU(dataId, vals);

const subscribers = this.pendingRead.get(dataId);
this.pendingRead.delete(dataId);

// Notify all pending reads.
subscribers.forEach(resolve => resolve(vals));
subscribers.forEach(resolve => resolve(dTypeVals));
if (this.pendingDisposal.has(dataId)) {
this.pendingDisposal.delete(dataId);
this.disposeData(dataId);
}
return vals;
return dTypeVals;
}

private getValuesFromTexture(dataId: DataId): Float32Array {
Expand Down Expand Up @@ -1935,7 +1932,8 @@ export class MathBackendWebGL implements KernelBackend {
}
}

private cacheOnCPU(dataId: DataId, float32Values?: Float32Array) {
private convertAndCacheOnCPU(dataId: DataId, float32Values?: Float32Array):
TypedArray {
// In delayed storage mode, when the user reads data, we don't keep a
// copy on the gpu, to minimize likelihood of memory leak. We re-upload
// to gpu the next time a gpgpu program needs the texture.
Expand All @@ -1951,6 +1949,7 @@ export class MathBackendWebGL implements KernelBackend {
if (float32Values != null) {
texData.values = float32ToTypedArray(float32Values, dtype);
}
return texData.values;
}

private releaseTexture(
Expand Down
29 changes: 28 additions & 1 deletion src/tensor_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -642,7 +642,7 @@ describeWithFlags('tensor', ALL_ENVS, () => {

it('float32 dtype from boolean[]', () => {
const a = tf.tensor3d(
[[[false], [false]], [[true], [false]]], [2, 2, 1], 'float32');
[[[false], [false]], [[true], [false]]], [2, 2, 1], 'float32');
expect(a.dtype).toBe('float32');
expectArraysClose(a, [0, 0, 1, 0]);
});
Expand Down Expand Up @@ -984,6 +984,33 @@ describeWithFlags('tensor', ALL_ENVS, () => {
expect(a.cast('int32').dtype).toEqual('int32');
});

it('cast float32 -> int32. async download', async () => {
const a = tf.tensor1d([1, 2]);
const aInt = a.cast('int32');
expect(aInt.dtype).toEqual('int32');

const asyncData = await aInt.data();
expect(asyncData instanceof Int32Array).toEqual(true);
});

it('cast float32 -> int32. queued async download', async () => {
const a = tf.tensor1d([1, 2]);
const aInt = a.cast('int32');
expect(aInt.dtype).toEqual('int32');

const [first, second] = await Promise.all([aInt.data(), aInt.data()]);
expect(first instanceof Int32Array).toEqual(true);
expect(second instanceof Int32Array).toEqual(true);
});

it('cast float32 -> int32. sync download', () => {
const a = tf.tensor1d([1, 2]).cast('int32');
expect(a.dtype).toEqual('int32');

const data = a.dataSync();
expect(data instanceof Int32Array).toEqual(true);
});

it('cast float32 -> float32', () => {
const a = tf.tensor1d([1.0, 2.0]);
expect(a.cast('float32').dtype).toEqual('float32');
Expand Down

0 comments on commit 0e01701

Please sign in to comment.