Skip to content
This repository has been archived by the owner on Oct 17, 2021. It is now read-only.

Change spies to conform with window.fetch API. #434

Merged
merged 4 commits into from
Jan 25, 2019
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
261 changes: 165 additions & 96 deletions src/models_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,21 @@
* =============================================================================
*/

import {DataType, io, ones, randomNormal, Scalar, scalar, serialization, sum, Tensor, tensor1d, tensor2d, zeros, tensor3d} from '@tensorflow/tfjs-core';
import {DataType, io, ones, randomNormal, Scalar, scalar, serialization, sum, Tensor, tensor1d, tensor2d, tensor3d, zeros} from '@tensorflow/tfjs-core';
import {ConfigDict} from '@tensorflow/tfjs-core/dist/serialization';

import {Model} from './engine/training';
import * as tfl from './index';
import {PyJsonDict} from './keras_format/types';
import {Reshape} from './layers/core';
import {deserialize} from './layers/serialization';
import {loadModelInternal, ModelAndWeightsConfig, modelFromJSON} from './models';
import {convertPythonicToTs, convertTsToPythonic} from './utils/serialization_utils';
import {describeMathCPU, describeMathCPUAndGPU, expectTensorsClose} from './utils/test_utils';
import {version as layersVersion} from './version';
import {PyJsonDict} from './keras_format/types';

const OCTET_STREAM_TYPE = 'application/octet-stream';
const JSON_TYPE = 'application/json';

describeMathCPU('Nested model topology', () => {
it('Nested Sequential model: Sequential as first layer', done => {
Expand Down Expand Up @@ -215,7 +217,8 @@ describeMathCPU('Nested model topology', () => {
const outerModel = tfl.sequential({
layers: [
innerModel,
tfl.layers.dense({units: 1, kernelInitializer: 'zeros', useBias: false})
tfl.layers.dense(
{units: 1, kernelInitializer: 'zeros', useBias: false})
]
});

Expand Down Expand Up @@ -545,9 +548,12 @@ describeMathCPU('loadModel from URL', () => {
const setupFakeWeightFiles =
(fileBufferMap:
{[filename: string]: Float32Array|Int32Array|ArrayBuffer}) => {
spyOn(window, 'fetch').and.callFake((path: string) => {
return new Response(fileBufferMap[path]);
});
spyOn(window, 'fetch')
.and.callFake((path: string) => new Promise(resolve => {
resolve(new Response(fileBufferMap[path], {
'headers': {'Content-Type': OCTET_STREAM_TYPE}
}));
}));
};

const isModelConfigNestedValues = [false, true];
Expand Down Expand Up @@ -623,20 +629,28 @@ describeMathCPU('loadModel from URL', () => {
[{'name': `dense_6/bias`, 'dtype': 'float32', 'shape': [32]}],
}
];

spyOn(window, 'fetch').and.callFake((path: string) => {
if (path === 'model/model.json') {
return new Response(JSON.stringify({
modelTopology,
weightsManifest,
}));
} else if (path === 'model/weight_0') {
return new Response(
ones([32, 32], 'float32').dataSync() as Float32Array);
} else if (path === 'model/weight_1') {
return new Response(zeros([32], 'float32').dataSync() as Float32Array);
} else {
throw new Error(`Invalid path: ${path}`);
}
return new Promise((resolve, reject) => {
if (path === 'model/model.json') {
resolve(new Response(
JSON.stringify({
modelTopology,
weightsManifest,
}),
{'headers': {'Content-Type': JSON_TYPE}}));
} else if (path === 'model/weight_0') {
resolve(new Response(
ones([32, 32], 'float32').dataSync() as Float32Array,
{'headers': {'Content-Type': OCTET_STREAM_TYPE}}));
} else if (path === 'model/weight_1') {
resolve(new Response(
zeros([32], 'float32').dataSync() as Float32Array,
{'headers': {'Content-Type': OCTET_STREAM_TYPE}}));
} else {
reject(new Error(`Invalid path: ${path}`));
}
});
});

const model = await loadModelInternal('model/model.json');
Expand Down Expand Up @@ -681,25 +695,34 @@ describeMathCPU('loadModel from URL', () => {
}
];
spyOn(window, 'fetch').and.callFake((path: string) => {
if (path === 'model/model.json') {
return new Response(JSON.stringify({
modelTopology,
weightsManifest,
}));
} else if (path === 'model/weight_0') {
return new Response(
ones([10, 2], 'float32').dataSync() as Float32Array);
} else if (path === 'model/weight_1') {
return new Response(
zeros([2], 'float32').dataSync() as Float32Array);
} else if (path === 'model/weight_2') {
return new Response(
zeros([2, 1], 'float32').dataSync() as Float32Array);
} else if (path === 'model/weight_3') {
return new Response(ones([1], 'float32').dataSync() as Float32Array);
} else {
throw new Error(`Invalid path: ${path}`);
}
return new Promise((resolve, reject) => {
if (path === 'model/model.json') {
resolve(new Response(
JSON.stringify({
modelTopology,
weightsManifest,
}),
{'headers': {'Content-Type': JSON_TYPE}}));
} else if (path === 'model/weight_0') {
resolve(new Response(
ones([10, 2], 'float32').dataSync() as Float32Array,
{'headers': {'Content-Type': OCTET_STREAM_TYPE}}));
} else if (path === 'model/weight_1') {
resolve(new Response(
zeros([2], 'float32').dataSync() as Float32Array,
{'headers': {'Content-Type': OCTET_STREAM_TYPE}}));
} else if (path === 'model/weight_2') {
resolve(new Response(
zeros([2, 1], 'float32').dataSync() as Float32Array,
{'headers': {'Content-Type': OCTET_STREAM_TYPE}}));
} else if (path === 'model/weight_3') {
resolve(new Response(
ones([1], 'float32').dataSync() as Float32Array,
{'headers': {'Content-Type': OCTET_STREAM_TYPE}}));
} else {
reject(new Error(`Invalid path: ${path}`));
}
});
});

const model = await loadModelInternal('model/model.json');
Expand Down Expand Up @@ -741,24 +764,34 @@ describeMathCPU('loadModel from URL', () => {
}
];
spyOn(window, 'fetch').and.callFake((path: string) => {
if (path === 'model/model.json') {
return new Response(JSON.stringify({
modelTopology,
weightsManifest,
}));
} else if (path === 'model/weight_0') {
return new Response(
ones([10, 2], 'float32').dataSync() as Float32Array);
} else if (path === 'model/weight_1') {
return new Response(zeros([2], 'float32').dataSync() as Float32Array);
} else if (path === 'model/weight_2') {
return new Response(
zeros([2, 1], 'float32').dataSync() as Float32Array);
} else if (path === 'model/weight_3') {
return new Response(ones([1], 'float32').dataSync() as Float32Array);
} else {
throw new Error(`Invalid path: ${path}`);
}
return new Promise((resolve, reject) => {
if (path === 'model/model.json') {
resolve(new Response(
JSON.stringify({
modelTopology,
weightsManifest,
}),
{'headers': {'Content-Type': JSON_TYPE}}));
} else if (path === 'model/weight_0') {
resolve(new Response(
ones([10, 2], 'float32').dataSync() as Float32Array,
{'headers': {'Content-Type': OCTET_STREAM_TYPE}}));
} else if (path === 'model/weight_1') {
resolve(new Response(
zeros([2], 'float32').dataSync() as Float32Array,
{'headers': {'Content-Type': OCTET_STREAM_TYPE}}));
} else if (path === 'model/weight_2') {
resolve(new Response(
zeros([2, 1], 'float32').dataSync() as Float32Array,
{'headers': {'Content-Type': OCTET_STREAM_TYPE}}));
} else if (path === 'model/weight_3') {
resolve(new Response(
ones([1], 'float32').dataSync() as Float32Array,
{'headers': {'Content-Type': OCTET_STREAM_TYPE}}));
} else {
reject(new Error(`Invalid path: ${path}`));
}
});
});

const model = await loadModelInternal('model/model.json');
Expand Down Expand Up @@ -795,28 +828,34 @@ describeMathCPU('loadModel from URL', () => {
}
];

const requestHeaders: Array<{}> = [];
const requestHeaders: Array<{[key: string]: string | {}}> = [];
const requestCredentials: string[] = [];
spyOn(window, 'fetch')
.and.callFake((path: string, requestInit?: RequestInit) => {
if (requestInit != null) {
requestHeaders.push(requestInit.headers);
requestHeaders.push(requestInit.headers as {});
requestCredentials.push(requestInit.credentials);
}
if (path === 'model/model.json') {
return new Response(JSON.stringify({
modelTopology,
weightsManifest,
}));
} else if (path === 'model/weight_0') {
return new Response(
ones([32, 32], 'float32').dataSync() as Float32Array);
} else if (path === 'model/weight_1') {
return new Response(
zeros([32], 'float32').dataSync() as Float32Array);
} else {
throw new Error(`Invalid path: ${path}`);
}
return new Promise((resolve, reject) => {
if (path === 'model/model.json') {
resolve(new Response(
JSON.stringify({
modelTopology,
weightsManifest,
}),
{'headers': {'Content-Type': JSON_TYPE}}));
} else if (path === 'model/weight_0') {
resolve(new Response(
ones([32, 32], 'float32').dataSync() as Float32Array,
{'headers': {'Content-Type': OCTET_STREAM_TYPE}}));
} else if (path === 'model/weight_1') {
resolve(new Response(
zeros([32], 'float32').dataSync() as Float32Array,
{'headers': {'Content-Type': OCTET_STREAM_TYPE}}));
} else {
reject(new Error(`Invalid path: ${path}`));
}
});
});

const model =
Expand All @@ -836,10 +875,30 @@ describeMathCPU('loadModel from URL', () => {

// Verify that the headers and credentials are sent via
// `fetch` properly.
expect(requestHeaders).toEqual([
{'header_key_1': 'header_value_1'}, {'header_key_1': 'header_value_1'},
{'header_key_1': 'header_value_1'}
]);
expect(requestHeaders[0]).toEqual(jasmine.objectContaining({
'header_key_1': 'header_value_1'
}));
if (requestHeaders[0]['Accept']) {
expect(requestHeaders[0]).toEqual(jasmine.objectContaining({
'Accept': 'application/json'
}));
}
expect(requestHeaders[1]).toEqual(jasmine.objectContaining({
'header_key_1': 'header_value_1'
}));
if (requestHeaders[1]['Accept']) {
expect(requestHeaders[1]).toEqual(jasmine.objectContaining({
'Accept': 'application/octet-stream'
}));
}
expect(requestHeaders[2]).toEqual(jasmine.objectContaining({
'header_key_1': 'header_value_1'
}));
if (requestHeaders[2]['Accept']) {
expect(requestHeaders[2]).toEqual(jasmine.objectContaining({
'Accept': 'application/octet-stream'
}));
}
expect(requestCredentials).toEqual(['include', 'include', 'include']);
});

Expand All @@ -852,9 +911,11 @@ describeMathCPU('loadModel from URL', () => {
const weightsManifest: io.WeightsManifestConfig = [
{
'paths': ['weight_0'],
'weights': [
{'name': `dense_6/kernel`, 'dtype': 'float32', 'shape': [32, 32]}
],
'weights': [{
'name': `dense_6/kernel`,
'dtype': 'float32',
'shape': [32, 32]
}],
},
{
'paths': ['weight_1'],
Expand All @@ -863,20 +924,26 @@ describeMathCPU('loadModel from URL', () => {
}
];
spyOn(window, 'fetch').and.callFake((path: string) => {
if (path === `${protocol}localhost:8888/models/model.json`) {
return new Response(JSON.stringify({
modelTopology,
weightsManifest,
}));
} else if (path === `${protocol}localhost:8888/models/weight_0`) {
return new Response(
ones([32, 32], 'float32').dataSync() as Float32Array);
} else if (path === `${protocol}localhost:8888/models/weight_1`) {
return new Response(
zeros([32], 'float32').dataSync() as Float32Array);
} else {
throw new Error(`Invalid path: ${path}`);
}
return new Promise((resolve, reject) => {
if (path === `${protocol}localhost:8888/models/model.json`) {
resolve(new Response(
JSON.stringify({
modelTopology,
weightsManifest,
}),
{'headers': {'Content-Type': JSON_TYPE}}));
} else if (path === `${protocol}localhost:8888/models/weight_0`) {
resolve(new Response(
ones([32, 32], 'float32').dataSync() as Float32Array,
{'headers': {'Content-Type': OCTET_STREAM_TYPE}}));
} else if (path === `${protocol}localhost:8888/models/weight_1`) {
resolve(new Response(
zeros([32], 'float32').dataSync() as Float32Array,
{'headers': {'Content-Type': OCTET_STREAM_TYPE}}));
} else {
reject(new Error(`Invalid path: ${path}`));
}
});
});

const model = await loadModelInternal(
Expand Down Expand Up @@ -937,9 +1004,11 @@ describeMathCPU('loadModel from URL', () => {
},
{
'paths': ['weight_1'],
'weights': [
{'name': `${denseLayerName}/bias`, 'dtype': 'float32', 'shape': [32]}
],
'weights': [{
'name': `${denseLayerName}/bias`,
'dtype': 'float32',
'shape': [32]
}],
}
];
// JSON.parse and stringify to deep copy fakeSequentialModel.
Expand Down