Skip to content

Commit

Permalink
Change the format of saliency clustering results as rows by clusters …
Browse files Browse the repository at this point in the history
…and add selection interaction with the data table

PiperOrigin-RevId: 496977258
  • Loading branch information
Googler committed Dec 21, 2022
1 parent 17bb31c commit 3a3aad3
Showing 1 changed file with 90 additions and 138 deletions.
228 changes: 90 additions & 138 deletions lit_nlp/client/modules/salience_clustering_module.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,14 @@ import {observable} from 'mobx';

import {app} from '../core/app';
import {LitModule} from '../core/lit_module';
import {LegendType} from '../elements/color_legend';
import {InterpreterClick} from '../elements/interpreter_controls';
import {SortableTemplateResult, TableData} from '../elements/table';
import {TableData} from '../elements/table';
import {CategoryLabel, FieldMatcher, TokenSalience} from '../lib/lit_types';
import {styles as sharedStyles} from '../lib/shared_styles.css';
import {CallConfig, IndexedInput, Input, ModelInfoMap} from '../lib/types';
import {CallConfig, IndexedInput, ModelInfoMap} from '../lib/types';
import {cloneSpec, createLitType, findSpecKeys} from '../lib/utils';
import {SalienceCmap, UnsignedSalienceCmap} from '../services/color_service';
import {ColumnData} from '../services/data_service';
import {DataService} from '../services/services';

Expand All @@ -39,7 +41,6 @@ import {styles} from './salience_clustering_module.css';
const SALIENCE_MAPPER_KEY = 'Salience Mapper';
const N_CLUSTERS_KEY = 'Number of Clusters';
const SALIENCE_CLUSTERING_INTERPRETER_NAME = 'Salience Clustering';
const RESULT_TOP_TOKENS = 'Top Tokens';
const REUSE_CLUSTERING = 'reuse_clustering';

interface ModuleState {
Expand All @@ -56,9 +57,9 @@ interface ClusteringState {
isLoading: boolean;
}

// Clustering result for a single piece of text.
// Clustering assignment result for each data point.
interface ClusterInfo {
example: Input;
exampleId: string;
clusterId: number;
}

Expand All @@ -73,11 +74,6 @@ interface TopTokenInfo {
weight: number;
}

// Indicators which expandable areas are open or closed.
interface VisToggles {
[name: string]: boolean;
}

/**
* A LIT module that renders salience clustering results.
*/
Expand Down Expand Up @@ -111,10 +107,6 @@ export class SalienceClusteringModule extends LitModule {
},
};

@observable private readonly expanded: VisToggles = {
[RESULT_TOP_TOKENS]: true,
};

override firstUpdated() {
const state: ModuleState = {
dataColumns: [],
Expand Down Expand Up @@ -183,8 +175,8 @@ export class SalienceClusteringModule extends LitModule {
const dataMap: ColumnData = new Map();

for (let i = 0; i < clusterIds.length; i++) {
const clusterInfo: ClusterInfo = {
example: inputs[i].data, clusterId: clusterIds[i]};
const clusterInfo:
ClusterInfo = {exampleId: inputs[i].id, clusterId: clusterIds[i]};
clusterInfos.push(clusterInfo);
dataMap.set(inputs[i].id, clusterInfo.clusterId.toString());
}
Expand All @@ -194,7 +186,9 @@ export class SalienceClusteringModule extends LitModule {
this.dataService.addColumn(
dataMap, SALIENCE_CLUSTERING_INTERPRETER_NAME, featName, dataType,
'Interpreter', localGetValueFn);
this.statusMessage = `New column added: ${featName}.`;
const numDataPoints = clusterInfos.length;
this.statusMessage = `Clustered ${numDataPoints} datapoints ` +
`(Generated column name: ${featName}).`;

// Store top tokens.
this.state.clusteringState.topTokenInfosByClusters[gradKey] = [];
Expand Down Expand Up @@ -223,127 +217,73 @@ export class SalienceClusteringModule extends LitModule {
<lit-spinner size=${24} color="var(--app-secondary-color)">
</lit-spinner>
</div>`;
// clang format on
}

renderDataTable() {
const renderSingleGradKeyClusterInfos =
(gradKey: string, clusterInfos: ClusterInfo[]) => {
const rows: TableData[] = clusterInfos.map((clusterInfo) => {
const row: string[] = [];
for (const dataColumn of this.state.dataColumns) {
row.push(clusterInfo.example[dataColumn]);
}
row.push(clusterInfo.clusterId.toString());
return row;
});

// clang-format off
return html`
<div class="grad-key-row">
<div class="grad-key-label">${gradKey}</div>
<lit-data-table
.columnNames=${
[...this.state.dataColumns, 'cluster ID']}
.data=${rows}
searchEnabled
selectionEnabled
paginationEnabled
></lit-data-table>
</div>`;
// clang format on
};

const renderClusterInfos =
(gradKeyClusterInfos: {[name: string]: ClusterInfo[]}) => {
const gradKeys = Object.keys(gradKeyClusterInfos);
// clang-format off
return html`
${
gradKeys.map(
gradKey => renderSingleGradKeyClusterInfos(
gradKey, gradKeyClusterInfos[gradKey]))}
${
this.state.clusteringState.isLoading ? this.renderSpinner() :
null}`;
// clang format on
};

return renderClusterInfos(this.state.clusteringState.clusterInfos);
// clang-format on
}

// Render a table that contains all top tokens and their weights per cluster.
private renderSingleGradKeyTopTokenInfos(gradKey: string,
topTokenInfosByClusters: TopTokenInfosByCluster[]) {
const clusterCount = topTokenInfosByClusters.length;
const maxTopTokenCount = Math.max(...topTokenInfosByClusters.map(
topTokenInfosByCluster => topTokenInfosByCluster.topTokenInfos.length
));

const columnNames: string[] = [];

for (let clusterId = 0; clusterId < clusterCount; clusterId++) {
columnNames.push(`Cluster ${clusterId}`);
}

const rows: TableData[] = [];
const tokenWeightStyle = styleMap({
'display': 'flex',
'flex-direction': 'row',
'justify-content': 'space-between',
'width': '100%',
// Render token chip.
// TODO(b/204887716): This needs to be replaced with a custom element.
private renderToken(token: string, weight: number, cmap: SalienceCmap,
gradKey: string) {
const tokenStyle = styleMap({
'background-color': cmap.bgCmap(weight),
'border-radius': '2px',
'color': cmap.textCmap(weight),
'margin': '2px 5px',
'padding': '1px 3px'
});

for (let exampleIdx = 0; exampleIdx < maxTopTokenCount;
exampleIdx++) {
const row: SortableTemplateResult[] = [];
return html`
<div class="salient-token-for-cluster" style=${tokenStyle}
title=${weight.toPrecision(3)} data-gradkey=${gradKey}>
${token}
</div>`;
}

for (let clusterId = 0; clusterId < clusterCount; clusterId++) {
const topTokenInfos = topTokenInfosByClusters[clusterId];
if (topTokenInfos.topTokenInfos.length < maxTopTokenCount) {
row.push({template: html``, value: 0});
} else {
const {token, weight} = topTokenInfos.topTokenInfos[exampleIdx];
row.push({
template: html`
<div style=${tokenWeightStyle}>
<span>${token}</span>
<span>(${weight.toFixed(2)})</span>
</div>
`,
value: weight
});
}
}
rows.push(row);
}
// Render a table that lists clusters with their top tokens.
private renderSingleGradKeyTopTokenInfos(
gradKey: string, topTokenInfosByClusters: TopTokenInfosByCluster[],
clusterInfos: ClusterInfo[]) {
const unsignedCmap = new UnsignedSalienceCmap();

const rowsByClusters: TableData[] =
topTokenInfosByClusters.map((topTokenInfos, clusterIdx) => {
const tokensDom = topTokenInfos.topTokenInfos.map(
tokenInfo => this.renderToken(
tokenInfo.token, tokenInfo.weight, unsignedCmap, gradKey));
return [clusterIdx, html`${tokensDom}`];
});

const onSelectClusters = (clusterIdxs: number[]) => {
const dataPointIds: string[] = clusterInfos
.filter(clusterInfo => clusterIdxs.includes(clusterInfo.clusterId))
.map(clusterInfo => clusterInfo.exampleId);
this.selectionService.selectIds(dataPointIds, this);
};

// clang-format off
return html`
<div class="grad-key-row">
<div class="grad-key-label">${gradKey}</div>
<lit-data-table
.columnNames=${columnNames}
.data=${rows}
searchEnabled
.columnNames=${['Cluster Index', 'Top Tokens']}
.data=${rowsByClusters}
selectionEnabled
paginationEnabled
.onSelect=${onSelectClusters}
></lit-data-table>
</div>`;
// clang format on
// clang-format on
}

renderTopTokens() {
const gradKeyTopTokenInfos =
this.state.clusteringState.topTokenInfosByClusters;
const gradKeys = Object.keys(gradKeyTopTokenInfos);
const {topTokenInfosByClusters, clusterInfos} = this.state.clusteringState;
const gradKeys = Object.keys(topTokenInfosByClusters);
// clang-format off
return html`
${gradKeys.map(
gradKey => this.renderSingleGradKeyTopTokenInfos(
gradKey, gradKeyTopTokenInfos[gradKey]))}
gradKey, topTokenInfosByClusters[gradKey], clusterInfos[gradKey]))}
${this.state.clusteringState.isLoading ? this.renderSpinner() : null}`;
// clang format on
// clang-format on
}

private getSalienceInterpreterNames() {
Expand Down Expand Up @@ -379,7 +319,6 @@ export class SalienceClusteringModule extends LitModule {
this.state.salienceConfigs[name] = settings;
};

// clang-format off
const renderInterpreterControls = (name: string) => {
const spec = this.appState.metadata.interpreters[name].configSpec;
const clonedSpec = cloneSpec(spec);
Expand All @@ -394,6 +333,7 @@ export class SalienceClusteringModule extends LitModule {
if (Object.keys(clonedSpec).length === 0) {
return;
}
// clang-format off
return html`
<lit-interpreter-controls
.spec=${clonedSpec}
Expand All @@ -404,30 +344,22 @@ export class SalienceClusteringModule extends LitModule {
moduleControlsApplyCallback :
methodControlsApplyCallback}>
</lit-interpreter-controls>`;
// clang-format on
};
// Always show the clustering config first.
const interpreters: string[] = [SALIENCE_CLUSTERING_INTERPRETER_NAME,
...this.getSalienceInterpreterNames()];
const expansionArea = (resultName: string) => {
let content = html``;
const interpreters: string[] = [
SALIENCE_CLUSTERING_INTERPRETER_NAME,
...this.getSalienceInterpreterNames()
];

if (resultName === RESULT_TOP_TOKENS) {
content = this.renderTopTokens();
}
return html`
<expansion-panel
.label=${resultName}
?expanded=${this.expanded[resultName]}>
${content}
</expansion-panel>`;
};
// clang-format off
return html`
<div class="controls-and-results-container">
<div class="clustering-controls-container">
${interpreters.map(renderInterpreterControls)}
</div>
<div class="clustering-results-container">
${expansionArea(RESULT_TOP_TOKENS)}
${this.renderTopTokens()}
</div>
</div>`;
// clang-format on
Expand All @@ -441,11 +373,31 @@ export class SalienceClusteringModule extends LitModule {
private renderSelectionWarning() {
// clang-format off
return html`
<div class="selection-warning">
<span class="selection-warning">
Please select no datapoint (for entire dataset) or >= 2 datapoints to
compute clusters.
</div>`;
// clang format on
</span>`;
// clang-format on
}

private renderColorLegend() {
const colorMap = new UnsignedSalienceCmap();

// TODO(b/263270935): Use a toColorLegend method to avoid D3-like style.
function scale(val: number) {
return colorMap.bgCmap(val);
}
scale.domain = () => colorMap.colorScale.domain();

// clang-format off
return html`
<div class="color-legend-container">
<color-legend legendType=${LegendType.SEQUENTIAL}
.scale=${scale}
numBlocks=5}>
</color-legend>
</div>`;
// clang-format on
}

override renderImpl() {
Expand All @@ -458,12 +410,12 @@ export class SalienceClusteringModule extends LitModule {
<div class="module-footer">
<p class="module-status">
${this.canRunClustering ? null : this.renderSelectionWarning()}
${this.statusMessage}
</p>
<p class="module-status">${this.statusMessage}</p>
${this.renderColorLegend()}
</div>
</div>`;
// clang format on

// clang-format on
}

static override shouldDisplayModule(modelSpecs: ModelInfoMap) {
Expand Down

0 comments on commit 3a3aad3

Please sign in to comment.