Skip to content

Commit

Permalink
Create LitTypes from a generic LitType constructor instead of a strin…
Browse files Browse the repository at this point in the history
…g LitName.

PiperOrigin-RevId: 465625447
  • Loading branch information
cjqian authored and LIT team committed Aug 5, 2022
1 parent cb528f1 commit a36a936
Show file tree
Hide file tree
Showing 10 changed files with 132 additions and 130 deletions.
10 changes: 5 additions & 5 deletions lit_nlp/client/lib/generated_text_utils_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,29 +18,29 @@
import 'jasmine';

import {canonicalizeGenerationResults, getAllOutputTexts, getAllReferenceTexts, getFlatTexts, getTextDiff} from './generated_text_utils';
import {LitType} from './lit_types';
import {GeneratedText, GeneratedTextCandidates, LitType, ReferenceTexts, TextSegment} from './lit_types';
import {Input, Preds, Spec} from './types';
import {createLitType} from './utils';

function textSegmentType(): LitType {
return createLitType('TextSegment', {
return createLitType(TextSegment, {
'required': false,
});
}

function referenceTextsType(): LitType {
return createLitType('ReferenceTexts', {
return createLitType(ReferenceTexts, {
'required': false,
});
}

function generatedTextType(parent: string): LitType {
return createLitType('GeneratedText', {'required': false, 'parent': parent});
return createLitType(GeneratedText, {'required': false, 'parent': parent});
}

function generatedTextCandidatesType(parent: string): LitType {
return createLitType(
'GeneratedTextCandidates', {'required': false, 'parent': parent});
GeneratedTextCandidates, {'required': false, 'parent': parent});
}

describe('canonicalizeGenerationResults test', () => {
Expand Down
9 changes: 4 additions & 5 deletions lit_nlp/client/lib/lit_types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,12 @@
* A dictionary of registered LitType names mapped to their constructor.
* LitTypes are added using the @registered decorator.
*/
export const REGISTRY: {[litType: string]: LitType} = {};
// tslint:disable-next-line:no-any
function registered(target: any) {
REGISTRY[target.name] = target;
export const LIT_TYPES_REGISTRY: {[litType: string]: new () => LitType} = {};
function registered(target: new () => LitType) {
LIT_TYPES_REGISTRY[target.name] = target;
}

const registryKeys : string[] = Object.keys(REGISTRY);
const registryKeys : string[] = Object.keys(LIT_TYPES_REGISTRY);
/**
* The types of all LitTypes in the registry, e.g.
* 'StringLitType' | 'TextSegment' ...
Expand Down
71 changes: 36 additions & 35 deletions lit_nlp/client/lib/testing_utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import 'jasmine';

import {AttentionHeads, BooleanLitType, CategoryLabel, Embeddings, MulticlassPreds, Scalar, TextSegment, TokenGradients, Tokens} from './lit_types';
import {LitMetadata, SerializedLitMetadata} from './types';
import {createLitType} from './utils';

Expand Down Expand Up @@ -56,25 +57,25 @@ export const mockMetadata: LitMetadata = {
'sst_0_micro': {
'spec': {
'input': {
'passage': createLitType('TextSegment'),
'passage': createLitType(TextSegment),
'passage_tokens':
createLitType('Tokens', {'required': false, 'parent': 'passage'}),
createLitType(Tokens, {'required': false, 'parent': 'passage'}),
},
'output': {
'probabilities': createLitType(
'MulticlassPreds',
MulticlassPreds,
{'vocab': ['0', '1'], 'null_idx': 0, 'parent': 'label'}),
'pooled_embs': createLitType('Embeddings'),
'mean_word_embs': createLitType('Embeddings'),
'tokens': createLitType('Tokens'),
'passage_tokens': createLitType('Tokens', {'parent': 'passage'}),
'pooled_embs': createLitType(Embeddings),
'mean_word_embs': createLitType(Embeddings),
'tokens': createLitType(Tokens),
'passage_tokens': createLitType(Tokens, {'parent': 'passage'}),
'passage_grad':
createLitType('TokenGradients', {'align': 'passage_tokens'}),
'layer_0/attention': createLitType('AttentionHeads', {
createLitType(TokenGradients, {'align': 'passage_tokens'}),
'layer_0/attention': createLitType(AttentionHeads, {
'align_in': 'tokens',
'align_out': 'tokens',
}),
'layer_1/attention': createLitType('AttentionHeads', {
'layer_1/attention': createLitType(AttentionHeads, {
'align_in': 'tokens',
'align_out': 'tokens',
}),
Expand All @@ -89,25 +90,25 @@ export const mockMetadata: LitMetadata = {
'sst_1_micro': {
'spec': {
'input': {
'passage': createLitType('TextSegment'),
'passage': createLitType(TextSegment),
'passage_tokens':
createLitType('Tokens', {'required': false, 'parent': 'passage'}),
createLitType(Tokens, {'required': false, 'parent': 'passage'}),
},
'output': {
'probabilities': createLitType(
'MulticlassPreds',
MulticlassPreds,
{'vocab': ['0', '1'], 'null_idx': 0, 'parent': 'label'}),
'pooled_embs': createLitType('Embeddings'),
'mean_word_embs': createLitType('Embeddings'),
'tokens': createLitType('Tokens'),
'passage_tokens': createLitType('Tokens', {'parent': 'passage'}),
'pooled_embs': createLitType(Embeddings),
'mean_word_embs': createLitType(Embeddings),
'tokens': createLitType(Tokens),
'passage_tokens': createLitType(Tokens, {'parent': 'passage'}),
'passage_grad':
createLitType('TokenGradients', {'align': 'passage_tokens'}),
'layer_0/attention': createLitType('AttentionHeads', {
createLitType(TokenGradients, {'align': 'passage_tokens'}),
'layer_0/attention': createLitType(AttentionHeads, {
'align_in': 'tokens',
'align_out': 'tokens',
}),
'layer_1/attention': createLitType('AttentionHeads', {
'layer_1/attention': createLitType(AttentionHeads, {
'align_in': 'tokens',
'align_out': 'tokens',
})
Expand All @@ -124,47 +125,47 @@ export const mockMetadata: LitMetadata = {
'sst_dev': {
'size': 872,
'spec': {
'passage': createLitType('TextSegment'),
'label': createLitType('CategoryLabel', {'vocab': ['0', '1']}),
'passage': createLitType(TextSegment),
'label': createLitType(CategoryLabel, {'vocab': ['0', '1']}),
}
},
'color_test': {
'size': 2,
'spec': {
'testNumFeat0': createLitType('Scalar'),
'testNumFeat1': createLitType('Scalar'),
'testFeat0': createLitType('CategoryLabel', {'vocab': ['0', '1']}),
'testFeat1': createLitType('CategoryLabel', {'vocab': ['a', 'b', 'c']})
'testNumFeat0': createLitType(Scalar),
'testNumFeat1': createLitType(Scalar),
'testFeat0': createLitType(CategoryLabel, {'vocab': ['0', '1']}),
'testFeat1': createLitType(CategoryLabel, {'vocab': ['a', 'b', 'c']})
}
},
'penguin_dev': {
'size': 10,
'spec': {
'body_mass_g': createLitType('Scalar', {
'body_mass_g': createLitType(Scalar, {
'step': 1,
}),
'culmen_depth_mm': createLitType('Scalar', {
'culmen_depth_mm': createLitType(Scalar, {
'step': 1,
}),
'culmen_length_mm': createLitType('Scalar', {
'culmen_length_mm': createLitType(Scalar, {
'step': 1,
}),
'flipper_length_mm': createLitType('Scalar', {
'flipper_length_mm': createLitType(Scalar, {
'step': 1,
}),
'island': createLitType(
'CategoryLabel', {'vocab': ['Biscoe', 'Dream', 'Torgersen']}),
'sex': createLitType('CategoryLabel', {'vocab': ['female', 'male']}),
CategoryLabel, {'vocab': ['Biscoe', 'Dream', 'Torgersen']}),
'sex': createLitType(CategoryLabel, {'vocab': ['female', 'male']}),
'species': createLitType(
'CategoryLabel', {'vocab': ['Adelie', 'Chinstrap', 'Gentoo']}),
'isAlive': createLitType('BooleanLitType', {'required': false})
CategoryLabel, {'vocab': ['Adelie', 'Chinstrap', 'Gentoo']}),
'isAlive': createLitType(BooleanLitType, {'required': false})
}
}
},
'generators': {
'word_replacer': {
'configSpec': {
'Substitutions': createLitType('TextSegment', {
'Substitutions': createLitType(TextSegment, {
'default': 'great -> terrible'
}),
},
Expand Down
51 changes: 31 additions & 20 deletions lit_nlp/client/lib/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import {html, TemplateResult} from 'lit';
import {unsafeHTML} from 'lit/directives/unsafe-html.js';

import {marked} from 'marked';
import {LitName, LitType, LitTypeTypesList, LitTypeWithParent, MulticlassPreds, REGISTRY} from './lit_types';
import {LitName, LitType, LitTypeTypesList, LitTypeWithParent, MulticlassPreds, LIT_TYPES_REGISTRY} from './lit_types';
import {FacetMap, LitMetadata, ModelInfoMap, SerializedLitMetadata, SerializedSpec, Spec} from './types';

/** Calculates the mean for a list of numbers */
Expand Down Expand Up @@ -111,42 +111,52 @@ export function getTypes(litNames: LitName|LitName[]) : any {
litNames = [litNames];
}

return litNames.map(litName => REGISTRY[litName]);
return litNames.map(litName => LIT_TYPES_REGISTRY[litName]);
}

/**
* Creates and returns a new LitType instance.
* @param typeName: The name of the desired LitType.
* @param litTypeConstructor: A constructor for the LitType instance.
* @param constructorParams: A dictionary of properties to set on the LitType.
* For example, {'show_in_data_table': true}.
*/
export function createLitType(
typeName: LitName, constructorParams: {[key: string]: unknown} = {}) {
const litType = REGISTRY[typeName];
*
* We use this helper instead of directly creating a new T(), because this
* allows creation of LitTypes dynamically from the metadata returned from the
* server via the `/get_info` API, and allows updating class properties on
* creation time.
*/
export function createLitType<T>(
litTypeConstructor: new () => T,
constructorParams: {[key: string]: unknown} = {}): T {
const litType = new litTypeConstructor();
// Temporarily make this LitType generic to set properties dynamically.
// tslint:disable-next-line:no-any
const newType = new (litType as any)();
newType.__name__ = typeName;
const genericLitType = litType as any;
// TODO(b/162269499): Consider removing __name__ property.
genericLitType.__name__ = litTypeConstructor.name;

for (const key in constructorParams) {
if (key in newType) {
newType[key] = constructorParams[key];
if (key in genericLitType) {
genericLitType[key] = constructorParams[key];
} else {
throw new Error(
`Attempted to set unrecognized property ${key} on ${typeName}.`);
throw new Error(`Attempted to set unrecognized property ${key} on ${
genericLitType.__name__}.`);
}
}

return newType;
return genericLitType as T;
}

/**
* Converts serialized LitTypes within a Spec into LitType instances.
*/
export function deserializeLitTypesInSpec(serializedSpec: SerializedSpec): Spec {
export function deserializeLitTypesInSpec(serializedSpec: SerializedSpec):
Spec {
const typedSpec: Spec = {};
for (const key of Object.keys(serializedSpec)) {
typedSpec[key] =
createLitType(serializedSpec[key].__name__, serializedSpec[key] as {});
typedSpec[key] = createLitType(
LIT_TYPES_REGISTRY[serializedSpec[key].__name__],
serializedSpec[key] as {});
}
return typedSpec;
}
Expand All @@ -157,16 +167,17 @@ export function deserializeLitTypesInSpec(serializedSpec: SerializedSpec): Spec
export function cloneSpec(spec: Spec): Spec {
const newSpec: Spec = {};
for (const [key, fieldSpec] of Object.entries(spec)) {
newSpec[key] = createLitType(fieldSpec.__name__, fieldSpec as {});
newSpec[key] =
createLitType(LIT_TYPES_REGISTRY[fieldSpec.__name__], fieldSpec as {});
}
return newSpec;
}

/**
* Converts serialized LitTypes within the LitMetadata into LitType instances.
*/
export function deserializeLitTypesInLitMetadata(metadata: SerializedLitMetadata) :
LitMetadata {
export function deserializeLitTypesInLitMetadata(
metadata: SerializedLitMetadata): LitMetadata {
for (const model of Object.keys(metadata.models)) {
metadata.models[model].spec.input =
deserializeLitTypesInSpec(metadata.models[model].spec.input);
Expand Down
Loading

0 comments on commit a36a936

Please sign in to comment.