diff --git a/lit_nlp/client/modules/scalar_module.ts b/lit_nlp/client/modules/scalar_module.ts index b2737d3f..d6a016bb 100644 --- a/lit_nlp/client/modules/scalar_module.ts +++ b/lit_nlp/client/modules/scalar_module.ts @@ -32,7 +32,7 @@ import {ThresholdChange} from '../elements/threshold_slider'; import {styles as sharedStyles} from '../lib/shared_styles.css'; import {D3Selection, formatForDisplay, IndexedInput, ModelInfoMap, ModelSpec, Preds, Spec} from '../lib/types'; import {doesOutputSpecContain, findSpecKeys, getThresholdFromMargin, isLitSubtype} from '../lib/utils'; -import {CalculatedColumnType, CLASSIFICATION_SOURCE_PREFIX, REGRESSION_SOURCE_PREFIX} from '../services/data_service'; +import {CalculatedColumnType, CLASSIFICATION_SOURCE_PREFIX, REGRESSION_SOURCE_PREFIX, SCALAR_SOURCE_PREFIX} from '../services/data_service'; import {FocusData} from '../services/focus_service'; import {ClassificationService, ColorService, DataService, FocusService, GroupService, SelectionService} from '../services/services'; @@ -118,7 +118,8 @@ export class ScalarModule extends LitModule { return true; } else if (col.source.includes(REGRESSION_SOURCE_PREFIX)) { return false; - } else if (col.source.includes(CLASSIFICATION_SOURCE_PREFIX)) { + } else if (col.source.includes(CLASSIFICATION_SOURCE_PREFIX) || + col.source.includes(SCALAR_SOURCE_PREFIX)) { return col.source.includes(this.model); } else { return true; @@ -126,12 +127,6 @@ export class ScalarModule extends LitModule { }); } - @computed - private get scalarModelOutputKeys() { - const outputSpec = this.appState.currentModelSpecs[this.model].spec.output; - return findSpecKeys(outputSpec, 'Scalar'); - } - @computed private get classificationKeys() { const outputSpec = this.appState.currentModelSpecs[this.model].spec.output; @@ -365,8 +360,6 @@ export class ScalarModule extends LitModule { currentInputData, this.model, dataset, ['MulticlassPreds']), this.apiService.getPreds( currentInputData, this.model, dataset, ['RegressionScore']), - this.apiService.getPreds( - currentInputData, this.model, dataset, ['Scalar']), ]); const results = await this.loadLatest('predictionScores', promise); if (results === null) { @@ -374,9 +367,7 @@ export class ScalarModule extends LitModule { } const classificationPreds = results[0]; const regressionPreds = results[1]; - const scalarPreds = results[2]; - if (classificationPreds == null && regressionPreds == null && - scalarPreds == null) { + if (classificationPreds == null && regressionPreds == null) { return; } @@ -386,7 +377,7 @@ export class ScalarModule extends LitModule { // TODO(lit-dev): structure this as a proper IndexedInput, // rather than having 'id' as a regular field. const pred = Object.assign( - {}, classificationPreds[i], scalarPreds[i], regressionPreds[i], + {}, classificationPreds[i], regressionPreds[i], {id: currId}); for (const scalarKey of this.scalarColumnsToPlot) { pred[scalarKey] = this.dataService.getVal(currId, scalarKey); @@ -819,7 +810,6 @@ export class ScalarModule extends LitModule { // clang-format off return html`
- ${this.scalarModelOutputKeys.map(key => this.renderPlot(key, ''))} ${this.classificationKeys.map(key => this.renderClassificationGroup(key))} ${this.scalarColumnsToPlot.map(key => this.renderPlot(key, ''))} diff --git a/lit_nlp/client/services/data_service.ts b/lit_nlp/client/services/data_service.ts index 6bd388f0..784e920d 100644 --- a/lit_nlp/client/services/data_service.ts +++ b/lit_nlp/client/services/data_service.ts @@ -60,6 +60,8 @@ export type ColumnData = Map; export const CLASSIFICATION_SOURCE_PREFIX = 'Classification'; /** Column source prefix for columns from the regression interpreter. */ export const REGRESSION_SOURCE_PREFIX = 'Regression'; +/** Column source prefix for columns from scalar model outputs. */ +export const SCALAR_SOURCE_PREFIX = 'Scalar'; /** * Data service singleton, responsible for maintaining columns of computed data @@ -110,6 +112,7 @@ export class DataService extends LitService { } for (const model of this.appState.currentModels) { this.runRegression(model, this.appState.currentInputData); + this.runScalarPreds(model, this.appState.currentInputData); } }, {fireImmediately: true}); @@ -219,6 +222,33 @@ export class DataService extends LitService { } } + /** + * Run scalar predictions and store results in data service. + */ + private async runScalarPreds(model: string, data: IndexedInput[]) { + const {output} = this.appState.currentModelSpecs[model].spec; + if (findSpecKeys(output, 'Scalar').length === 0) { + return; + } + + const predsPromise = this.apiService.getPreds( + data, model, this.appState.currentDataset, ['Scalar']); + const preds = await predsPromise; + + // Add scalar results as new column to the data service. + if (preds == null || preds.length === 0) { + return; + } + const scalarKeys = Object.keys(preds[0]); + for (const key of scalarKeys) { + const scoreFeatName = this.getColumnName(model, key); + const scores = preds.map(pred => pred[key]); + const dataType = this.appState.createLitType('Scalar', false); + const source = `${SCALAR_SOURCE_PREFIX}:${model}`; + this.addColumnFromList(scores, data, scoreFeatName, dataType, source); + } + } + @action async setValuesForNewDatapoints(datapoints: IndexedInput[]) { // When new datapoints are created, set their data values for each