Skip to content

Commit

Permalink
Add makeModifiedInput to correctly set _id and _meta on new examples.
Browse files Browse the repository at this point in the history
Fixes a bug where modified examples still had the original ID, and thus were
returning cached predictions from the original example.

 - Fix manual-example-creation flow, which was incorrectly inheriting the parent ID.
 - Fix LM prediction module and TDA module which use ephemeral examples with
   empty ID.

PiperOrigin-RevId: 552580039
  • Loading branch information
iftenney authored and LIT team committed Jul 31, 2023
1 parent 7b96a6d commit 0ec1527
Show file tree
Hide file tree
Showing 8 changed files with 150 additions and 64 deletions.
15 changes: 13 additions & 2 deletions lit_nlp/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,8 +328,9 @@ def _annotate_new_data(self,

# Add annotations and IDs to new datapoints.
for i, example in enumerate(data['inputs']):
example['data'] = annotated_dataset.examples[i]
example['id'] = caching.input_hash(example['data'])
new_id = caching.input_hash(example['data'])
example['data'] = dict(annotated_dataset.examples[i], _id=new_id)
example['id'] = new_id

return data['inputs'] # pytype: disable=bad-return-type # always-use-return-annotations

Expand Down Expand Up @@ -817,6 +818,16 @@ def _handler(app: wsgi_app.App, request, environ):
data['inputs'] = self._reconstitute_inputs(
data['inputs'], kw['dataset_name']
)
# Validate that id and data._id match.
# TODO(b/171513556): consider removing this if we can simplify the
# data representation on the frontend so id and meta are not replicated.
for ex in data['inputs']:
if ex['id'] != ex['data'].get('_id'):
raise ValueError(
'Error: malformed example with inconsistent ids:'
f' {str(ex)}\nfrom request'
f' {request.path} {str(request.args.to_dict())}'
)

outputs = fn(data, **kw)
response_body = serialize.to_json(outputs, simple=response_simple_json)
Expand Down
25 changes: 22 additions & 3 deletions lit_nlp/client/lib/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,11 @@
// tslint:disable: enforce-name-casing

import * as d3 from 'd3'; // Used for array helpers.

import {unsafeHTML} from 'lit/directives/unsafe-html.js';

import {marked} from 'marked';
import {LitName, LitType, LitTypeTypesList, LitTypeWithParent, MulticlassPreds, LIT_TYPES_REGISTRY} from './lit_types';
import {CallConfig, FacetMap, ModelInfoMap, Spec} from './types';
import {LIT_TYPES_REGISTRY, LitName, LitType, LitTypeTypesList, LitTypeWithParent, MulticlassPreds} from './lit_types';
import {CallConfig, FacetMap, IndexedInput, ModelInfoMap, Spec} from './types';

/** Calculates the mean for a list of numbers */
export function mean(values: number[]): number {
Expand Down Expand Up @@ -603,3 +602,23 @@ export function validateCallConfig(
litType.required ? config[key] == null : false)
.map(([key, unused]) => key);
}

/**
 * Derives a new example from `input` when `overrides` would change at least
 * one field; otherwise hands back `input` itself.
 *
 * A field counts as changed when its override is not strictly equal (`!==`)
 * to the current value in `input.data`, so a freshly built array or object
 * with the same contents still counts as a change.
 *
 * @param input The example to start from; it is not mutated.
 * @param overrides Field values to layer on top of `input.data`.
 * @param source Optional label recorded in the new example's metadata.
 * @return `input` (same reference) if nothing differs. Otherwise a new
 *     IndexedInput whose `id` and `data._id` are empty strings, and whose
 *     `meta` and `data._meta` are one shared object of the form
 *     `{added: true, source, parentId}` with `parentId` set to `input.id`.
 */
export function makeModifiedInput(
    input: IndexedInput, overrides: {[key: string]: unknown},
    source?: string): IndexedInput {
  const hasChange = Object.keys(overrides).some(
      field => input.data[field] !== overrides[field]);
  if (!hasChange) return input;

  const childMeta = {added: true, source, parentId: input.id};
  // '_id' and '_meta' are applied last so they win over any same-named keys
  // coming from input.data or overrides.
  const childData = Object.assign(
      {}, input.data, overrides, {'_id': '', '_meta': childMeta});
  return {data: childData, id: '', meta: childMeta};
}
79 changes: 78 additions & 1 deletion lit_nlp/client/lib/utils_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import 'jasmine';

import {CategoryLabel, GeneratedText, LitType, MulticlassPreds, RegressionScore, Scalar, StringLitType, TextSegment, TokenGradients} from '../lib/lit_types';
import {CallConfig, Spec} from '../lib/types';
import {CallConfig, IndexedInput, Spec} from '../lib/types';

import * as utils from './utils';

Expand Down Expand Up @@ -766,3 +766,80 @@ describe('validateCallConfig test', () => {
});
});
});


describe('makeModifiedInput test', () => {
  // Fixture: an example whose id and metadata are mirrored inside `data`.
  const original: IndexedInput = {
    'data': {
      'foo': 123,
      'bar': 234,
      '_id': 'a1b2c3',
      '_meta': {parentId: '000000'}
    },
    'id': 'a1b2c3',
    'meta': {parentId: '000000'}
  };

  /**
   * Builds the value a modified copy of `original` should equal, given the
   * user-visible fields it is expected to carry. The id fields are empty and
   * the metadata points back at the fixture's id.
   */
  function expectedModified(fields: {[key: string]: unknown}) {
    const meta = {added: true, source: 'testFn', parentId: original.id};
    return {
      'data': {...fields, '_id': '', '_meta': meta},
      'id': '',
      'meta': meta,
    };
  }

  it('overrides one field', () => {
    const actual = utils.makeModifiedInput(original, {'bar': 345}, 'testFn');
    expect(actual).toEqual(expectedModified({'foo': 123, 'bar': 345}));
  });

  it('overrides two fields', () => {
    const actual =
        utils.makeModifiedInput(original, {'foo': 234, 'bar': 345}, 'testFn');
    expect(actual).toEqual(expectedModified({'foo': 234, 'bar': 345}));
  });

  it('adds a field with a new name', () => {
    const actual =
        utils.makeModifiedInput(original, {'baz': 'spam and eggs'}, 'testFn');
    expect(actual).toEqual(
        expectedModified({'foo': 123, 'bar': 234, 'baz': 'spam and eggs'}));
  });

  it('returns original if no fields modified', () => {
    // Every override equals the value already present, so nothing changes.
    const actual =
        utils.makeModifiedInput(original, {'foo': 123, 'bar': 234}, 'testFn');
    expect(actual).toEqual(original);
  });

  it('returns original if overrides empty', () => {
    const actual = utils.makeModifiedInput(original, {}, 'testFn');
    expect(actual).toEqual(original);
  });
});
41 changes: 12 additions & 29 deletions lit_nlp/client/modules/datapoint_editor_module.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,14 @@ import {customElement} from 'lit/decorators.js';
import {classMap} from 'lit/directives/class-map.js';
import {styleMap} from 'lit/directives/style-map.js';
import {computed, observable, when} from 'mobx';

import {app} from '../core/app';
import {LitModule} from '../core/lit_module';
import {AnnotationCluster, EdgeLabel, SpanLabel} from '../lib/dtypes';
import {BooleanLitType, EdgeLabels, Embeddings, ImageBytes, ListLitType, LitTypeWithVocab, MultiSegmentAnnotations, SearchQuery, SequenceTags, SpanLabels, SparseMultilabel, StringLitType, Tokens, URLLitType} from '../lib/lit_types';
import {styles as sharedStyles} from '../lib/shared_styles.css';
import {defaultValueByField, formatAnnotationCluster, formatEdgeLabel, formatSpanLabel, IndexedInput, Input, ModelInfoMap, SCROLL_SYNC_CSS_CLASS, Spec} from '../lib/types';
import {findSpecKeys, isLitSubtype} from '../lib/utils';
import {formatAnnotationCluster, formatEdgeLabel, formatSpanLabel, IndexedInput, Input, ModelInfoMap, SCROLL_SYNC_CSS_CLASS, Spec} from '../lib/types';
import {findSpecKeys, isLitSubtype, makeModifiedInput} from '../lib/utils';
import {GroupService} from '../services/group_service';
import {SelectionService} from '../services/selection_service';

Expand Down Expand Up @@ -72,26 +73,16 @@ export class DatapointEditorModule extends LitModule {
protected showAddAndCompare = true;

@computed
get emptyDatapoint(): Input {
const data: Input = {};
const spec = this.appState.currentDatasetSpec;
for (const key of this.appState.currentInputDataKeys) {
data[key] = defaultValueByField(key, spec);
}
return data;
}

@computed
get baseData(): Input {
get baseData(): IndexedInput {
const input = this.selectionService.primarySelectedInputData;
return input == null ? this.emptyDatapoint : input.data;
return input ?? this.appState.makeEmptyDatapoint('manual');
}

@observable dataEdits: Input = {};

@computed
get editedData(): Input {
return Object.assign({}, this.baseData, this.dataEdits);
get editedData(): IndexedInput {
return makeModifiedInput(this.baseData, this.dataEdits, 'manual');
}

@computed
Expand Down Expand Up @@ -297,16 +288,8 @@ export class DatapointEditorModule extends LitModule {
const clearEnabled = this.selectionService.primarySelectedInputData != null;

const onClickNew = async () => {
const datum: IndexedInput = {
data: this.editedData,
id: '', // will be overwritten
meta: {
source: 'manual',
added: true,
parentId: this.selectionService.primarySelectedId!
},
};
const data: IndexedInput[] = await this.appState.annotateNewData([datum]);
const data: IndexedInput[] =
await this.appState.annotateNewData([this.editedData]);
this.appState.commitNewDatapoints(data);
const newIds = data.map(d => d.id);
this.selectionService.selectIds(newIds);
Expand Down Expand Up @@ -369,7 +352,7 @@ export class DatapointEditorModule extends LitModule {
return html`
<div id="edit-table">
${keys.map(
key => this.renderEntry(key, this.editedData[key]))}
key => this.renderEntry(key, this.editedData.data[key]))}
</div>
`;
// clang-format on
Expand All @@ -381,7 +364,7 @@ export class DatapointEditorModule extends LitModule {
(e: Event, converterFn: InputConverterFn = (s => s)) => {
// tslint:disable-next-line:no-any
const value = converterFn((e as any).target.value as string);
if (value === this.baseData[key]) {
if (value === this.baseData.data[key]) {
delete this.dataEdits[key];
} else {
this.dataEdits[key] = value;
Expand Down Expand Up @@ -429,7 +412,7 @@ export class DatapointEditorModule extends LitModule {
const reader = new FileReader();
reader.addEventListener('load', () => {
const value = reader.result as string;
if (value === this.baseData[key]) {
if (value === this.baseData.data[key]) {
delete this.dataEdits[key];
} else {
this.dataEdits[key] = value;
Expand Down
18 changes: 4 additions & 14 deletions lit_nlp/client/modules/lm_prediction_module.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,19 @@

import '../elements/checkbox';

import {html} from 'lit';
// tslint:disable:no-new-decorators
import {customElement} from 'lit/decorators.js';
import { html} from 'lit';
import {classMap} from 'lit/directives/class-map.js';
import {computed, observable} from 'mobx';

import {LitModule} from '../core/lit_module';
import {TextSegment, Tokens, TokenTopKPreds} from '../lib/lit_types';
import {styles as sharedStyles} from '../lib/shared_styles.css';
import {IndexedInput, ModelInfoMap, Spec, TopKResult} from '../lib/types';
import {findMatchingIndices, findSpecKeys, replaceNth} from '../lib/utils';
import {findMatchingIndices, findSpecKeys, makeModifiedInput, replaceNth} from '../lib/utils';

import {styles} from './lm_prediction_module.css';
import {styles as sharedStyles} from '../lib/shared_styles.css';

/**
* A LIT module that renders masked predictions for a masked LM.
Expand Down Expand Up @@ -159,17 +159,7 @@ export class LanguageModelPredictionModule extends LitModule {
}

/**
 * Builds the example derived from `orig` with its input-tokens field replaced
 * by `tokens`.
 *
 * Delegates to makeModifiedInput so that the derived example gets an empty
 * `id`, matching empty `data._id`, and metadata
 * `{added: true, source: 'masked', parentId: orig.id}` on both `meta` and
 * `data._meta`. Building the object by hand left the parent's `_id` inside
 * `data`, which does not match the empty top-level `id`.
 *
 * @param orig The example to derive from; it is not mutated.
 * @param tokens Replacement value for the field named by `inputTokensKey`.
 * @return The derived example, or `orig` itself when `tokens` is strictly
 *     equal to the value already stored under that field.
 */
private createChildDatapoint(orig: IndexedInput, tokens: string[]) {
  return makeModifiedInput(orig, {[this.inputTokensKey!]: tokens}, 'masked');
}

private async updateMLMResults() {
Expand Down
10 changes: 2 additions & 8 deletions lit_nlp/client/modules/tda_module.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ import {canonicalizeGenerationResults, GeneratedTextResult, GENERATION_TYPES, ge
import {FieldMatcher, InfluentialExamples, LitTypeWithParent} from '../lib/lit_types';
import {styles as sharedStyles} from '../lib/shared_styles.css';
import {CallConfig, ComponentInfoMap, IndexedInput, Input, ModelInfoMap, Spec} from '../lib/types';
import {cloneSpec, filterToKeys, findSpecKeys} from '../lib/utils';
import {cloneSpec, filterToKeys, findSpecKeys, makeModifiedInput} from '../lib/utils';
import {AppState, SelectionService} from '../services/services';

import {styles} from './tda_module.css';
Expand Down Expand Up @@ -189,13 +189,7 @@ export class TrainingDataAttributionModule extends LitModule {
if (this.currentData == null) return undefined;
if (!Object.keys(this.customLabels).length) return this.currentData;

const modifiedInputData =
Object.assign({}, this.currentData.data, this.customLabels);
return {
data: modifiedInputData,
id: '',
meta: {added: true, source: 'tda_custom', parentId: this.currentData.id}
};
return makeModifiedInput(this.currentData, this.customLabels, 'tda_custom');
}

static compatibleGenerators(generatorInfo: ComponentInfoMap): string[] {
Expand Down
16 changes: 14 additions & 2 deletions lit_nlp/client/services/state_service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
import {action, computed, observable, toJS} from 'mobx';

import {FieldMatcher, ImageBytes} from '../lib/lit_types';
import {IndexedInput, LitCanonicalLayout, LitComponentLayouts, LitMetadata, ModelInfo, ModelInfoMap, ModelSpec, Spec} from '../lib/types';
import {getTypes, findSpecKeys} from '../lib/utils';
import {defaultValueByField, IndexedInput, Input, LitCanonicalLayout, LitComponentLayouts, LitMetadata, ModelInfo, ModelInfoMap, ModelSpec, Spec} from '../lib/types';
import {findSpecKeys, getTypes} from '../lib/utils';

import {ApiService} from './api_service';
import {LitService} from './lit_service';
Expand Down Expand Up @@ -249,6 +249,18 @@ export class AppState extends LitService implements StateObservedByUrlService {
}

//=================================== Generation logic
/**
 * Builds a blank example for the current dataset.
 *
 * Every key in `currentInputDataKeys` is filled with the value
 * `defaultValueByField` returns for it under the current dataset spec. The
 * example's `id` and `data._id` are empty strings, and `meta` is the same
 * value stored at `data._meta` (initially `{source, added: true}`).
 *
 * @param source Optional label recorded in the example's metadata.
 */
makeEmptyDatapoint(source?: string) {
  const datasetSpec = this.currentDatasetSpec;
  const fields: Input = {'_id': '', '_meta': {source, added: true}};
  for (const fieldName of this.currentInputDataKeys) {
    fields[fieldName] = defaultValueByField(fieldName, datasetSpec);
  }
  return {data: fields, id: '', meta: fields['_meta']};
}

/**
* Annotate one or more bare datapoints.
* @param data input examples; ids will be overwritten.
Expand Down
10 changes: 5 additions & 5 deletions lit_nlp/client/services/url_service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import {autorun} from 'mobx';

import {ListLitType} from '../lib/lit_types';
import {defaultValueByField, IndexedInput, Input, LitCanonicalLayout, LitMetadata, ServiceUser, Spec} from '../lib/types';
import {makeModifiedInput} from '../lib/utils';

import {LitService} from './lit_service';
import {ApiService} from './services';
Expand Down Expand Up @@ -63,6 +64,7 @@ export interface StateObservedByUrlService {
compareExamplesEnabled: boolean;
layoutName: string;
getCurrentInputDataById: (id: string) => IndexedInput | null;
makeEmptyDatapoint: (source?: string) => IndexedInput;
annotateNewData: (data: IndexedInput[]) => Promise<IndexedInput[]>;
commitNewDatapoints: (datapoints: IndexedInput[]) => void;
documentationOpen: boolean;
Expand Down Expand Up @@ -347,11 +349,9 @@ export class UrlService extends LitService {
Object.keys(spec).forEach(key => {
outputFields[key] = this.parseDataFieldValue(key, fields[key], spec);
});
const datum: IndexedInput = {
data: outputFields,
id: '', // will be overwritten
meta: {source: 'url', added: true},
};

const datum = makeModifiedInput(
appState.makeEmptyDatapoint(), outputFields, /* source */ 'url');
return datum;
});

Expand Down

0 comments on commit 0ec1527

Please sign in to comment.