Skip to content

Commit

Permalink
Add initial salience clustering UI module.
Browse files Browse the repository at this point in the history
The module contains the connection to the interpreter and result rendering. It renders results as a table of (text, cluster id) tuples. Follow-up CLs will add more functionality.

PiperOrigin-RevId: 424288027
  • Loading branch information
eberts-google authored and LIT team committed Jan 26, 2022
1 parent 6ec9a6c commit 8f3c26c
Show file tree
Hide file tree
Showing 5 changed files with 285 additions and 2 deletions.
10 changes: 9 additions & 1 deletion lit_nlp/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from lit_nlp.components import pca
from lit_nlp.components import pdp
from lit_nlp.components import projection
from lit_nlp.components import salience_clustering
from lit_nlp.components import scrambler
from lit_nlp.components import tcav
from lit_nlp.components import thresholder
Expand Down Expand Up @@ -458,23 +459,30 @@ def __init__(
'paired': metrics.MulticlassPairedMetrics(),
'bleu': metrics.CorpusBLEU(),
})
self._interpreters = {
gradient_map_interpreters = {
'Grad L2 Norm': gradient_maps.GradientNorm(),
'Grad ⋅ Input': gradient_maps.GradientDotInput(),
'Integrated Gradients': gradient_maps.IntegratedGradients(),
'LIME': lime_explainer.LIME(),
}
# pyformat: disable
self._interpreters = {
'Model-provided salience': model_salience.ModelSalience(self._models),
'counterfactual explainer': lemon_explainer.LEMON(),
'tcav': tcav.TCAV(),
'thresholder': thresholder.Thresholder(),
'nearest neighbors': nearest_neighbors.NearestNeighbors(),
'metrics': metrics_group,
'pdp': pdp.PdpInterpreter(),
'salience clustering': salience_clustering.SalienceClustering(
gradient_map_interpreters),
# Embedding projectors expose a standard interface, but get special
# handling so we can precompute the projections if requested.
'pca': projection.ProjectionManager(pca.PCAModel),
'umap': projection.ProjectionManager(umap.UmapModel),
}
# pyformat: enable
self._interpreters.update(gradient_map_interpreters)

# Information on models, datasets, and other components.
self._info = self._build_metadata()
Expand Down
2 changes: 2 additions & 0 deletions lit_nlp/client/default/layout.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ import {MetricsModule} from '../modules/metrics_module';
import {MultilabelModule} from '../modules/multilabel_module';
import {PdpModule} from '../modules/pdp_module';
import {RegressionModule} from '../modules/regression_module';
import {SalienceClusteringModule} from '../modules/salience_clustering_module';
import {SalienceMapModule} from '../modules/salience_map_module';
import {ScalarModule} from '../modules/scalar_module';
import {SequenceSalienceModule} from '../modules/sequence_salience_module';
Expand Down Expand Up @@ -108,6 +109,7 @@ export const LAYOUTS: LitComponentLayouts = {
SequenceSalienceModule,
AttentionModule,
],
'Clustering': [SalienceClusteringModule],
'Metrics': [
MetricsModule,
ConfusionMatrixModule,
Expand Down
22 changes: 22 additions & 0 deletions lit_nlp/client/modules/salience_clustering_module.css
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
.clustering-salience-mapper {
padding: 8px;
}

.grad-key-label {
padding: 4px;
text-align: start;
vertical-align: baseline;
width: 150px;
}

.grad-key-row {
border-top: 1px solid #e8eaed;
border-bottom: 1px solid #e8eaed;
padding: 8px;
}

.selection-warning {
padding: 4px;
position: relative; /* to allow overlay div */
vertical-align: top;
}
248 changes: 248 additions & 0 deletions lit_nlp/client/modules/salience_clustering_module.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,248 @@
/**
* @license
* Copyright 2022 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

import '../elements/line_chart';
import '../elements/bar_chart';

import {html} from 'lit';
// tslint:disable:no-new-decorators
import {customElement} from 'lit/decorators';
import {observable} from 'mobx';

import {LitModule} from '../core/lit_module';
import {TableData} from '../elements/table';
import {styles as sharedStyles} from '../lib/shared_styles.css';
import {CallConfig, Input, ModelInfoMap, Spec} from '../lib/types';

import {styles} from './salience_clustering_module.css';

interface SalienceMapperToClusteringState {
[salienceMapper: string]: ClusteringState;
}

interface ClusteringState {
dataColumns: string[];
clusterInfos: {[name: string]: ClusterInfo[]};
isLoading: boolean;
config?: CallConfig;
}

// Clustering result for a single piece of text.
interface ClusterInfo {
example: Input;
clusterId: number;
}

/**
* A LIT module that renders salience clustering results.
*/
@customElement('salience-clustering-module')
export class SalienceClusteringModule extends LitModule {
static override title = 'Salience Clustering Results';
static override numCols = 1;
// TODO(b/215497716): Get the salience mappers from the interpreter component
// and let the user select the wanted one.
private readonly salienceMapper = 'Grad L2 Norm';
static override template = (model = '', selectionServiceIndex = 0) => {
// clang-format off
return html`
<salience-clustering-module
model=${model} selectionServiceIndex=${selectionServiceIndex}>
</salience-clustering-module>`;
// clang format on
};

static override get styles() {
return [sharedStyles, styles];
}

// Mapping from salience mapper to clustering results.
@observable private state: SalienceMapperToClusteringState = {};

override firstUpdated() {
const state: SalienceMapperToClusteringState = {};
state[this.salienceMapper] = {
dataColumns: [],
clusterInfos: {},
isLoading: false,
config: {'salience_mapper': this.salienceMapper},
};
this.state = state;
}

private runInterpreterDefault() {
return this.runInterpreter(this.salienceMapper);
}

private async runInterpreter(salienceMapper: string) {
this.state[salienceMapper].clusterInfos = {};
const input = this.selectionService.selectedOrAllInputData;
if (!this.canRunClustering) {
return;
}

this.state[salienceMapper].isLoading = true;
const promise = this.apiService.getInterpretations(
input, this.model, this.appState.currentDataset, 'salience clustering',
this.state[salienceMapper].config, `Running ${salienceMapper}`);
const clusteringResult =
await this.loadLatest(`interpretations-${salienceMapper}`, promise);
this.state[salienceMapper].isLoading = false;
this.state[salienceMapper].dataColumns = Object.keys(input[0].data);

for (const gradKey of Object.keys(clusteringResult['cluster_ids'])) {
const clusterInfos: ClusterInfo[] = [];
const clusterIds = clusteringResult['cluster_ids'][gradKey];

for (let i = 0; i < clusterIds.length; i++) {
const clusterInfo: ClusterInfo = {
example: input[i].data, clusterId: clusterIds[i]};
clusterInfos.push(clusterInfo);
}
this.state[salienceMapper].clusterInfos[gradKey] = clusterInfos;
}
}

renderSpinner() {
// clang-format off
return html`
<div class="spinner-container">
<lit-spinner size=${24} color="var(--app-secondary-color)">
</lit-spinner>
</div>`;
// clang format on
}

renderClusteringState(salienceMapper: string) {
const renderSingleGradKeyClusterInfos =
(gradKey: string, clusterInfos: ClusterInfo[]) => {
const rows: TableData[] = clusterInfos.map((clusterInfo) => {
const row: string[] = [];
for (const dataColumn of this.state[salienceMapper].dataColumns) {
row.push(clusterInfo.example[dataColumn]);
}
row.push(clusterInfo.clusterId.toString());
return row;
});

// clang-format off
return html`
<div class="grad-key-row">
<div class="grad-key-label">${gradKey}</div>
<lit-data-table
.columnNames=${
[...this.state[salienceMapper].dataColumns, 'cluster ID']}
.data=${rows}
searchEnabled
selectionEnabled
paginationEnabled
></lit-data-table>
</div>`;
// clang format on
};

const renderClusterInfos =
(gradKeyClusterInfos: {[name: string]: ClusterInfo[]}) => {
const gradKeys = Object.keys(gradKeyClusterInfos);
// clang-format off
return html`
${
gradKeys.map(
gradKey => renderSingleGradKeyClusterInfos(
gradKey, gradKeyClusterInfos[gradKey]))}
${
this.state[salienceMapper].isLoading ? this.renderSpinner() :
null}`;
// clang format on
};

return renderClusterInfos(this.state[salienceMapper].clusterInfos);
}

renderTable() {
if (this.state[this.salienceMapper] == null ||
this.state[this.salienceMapper].clusterInfos == null) {
return html`<div>Nothing to show.</div>`;
}
// clang-format off
return html`
<div class="clustering-salience-mapper">${this.salienceMapper}</div>
${this.renderClusteringState(this.salienceMapper)}`;
// clang format on
}

private get canRunClustering() {
const input = this.selectionService.selectedOrAllInputData;
return (input != null && input.length >= 2);
}

renderSelectionWarning() {
if (this.canRunClustering) {
return html``;
} else {
// clang-format off
return html`
<div class="selection-warning">
Please select no datapoint (for entire dataset) or >= 2 datapoints to
compute clusters.
</div>`;
// clang format on
}
}

renderControls() {
// clang-format off
return html`
<button class='hairline-button'
?disabled="${!this.canRunClustering}"
@click=${this.runInterpreterDefault}>
Compute clusters
</button>
${this.renderSelectionWarning()}`;
// clang format on
}

override render() {
// clang-format off
return html`
<div class='module-container'>
<div class="module-toolbar">
${this.renderControls()}
</div>
<div class='module-results-area'>
${this.renderTable()}
</div>
</div>`;
// clang format on
}

static override shouldDisplayModule(modelSpecs: ModelInfoMap,
datasetSpec: Spec) {
for (const modelInfo of Object.values(modelSpecs)) {
if (modelInfo.interpreters.indexOf('salience clustering') !== -1) {
return true;
}
}
return false;
}
}

declare global {
interface HTMLElementTagNameMap {
'salience-clustering-module': SalienceClusteringModule;
}
}
5 changes: 4 additions & 1 deletion lit_nlp/components/salience_clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,10 +173,13 @@ def run_with_metadata(
the dataset that were used in the clustering.
"""
config = config or {}
# If no specific inputs provided, use the entire dataset.
inputs_to_use = indexed_inputs or dataset.examples

# Find gradient fields to interpret
grad_fields = self.find_fields(model.output_spec())
token_saliencies = self.salience_mappers[
config['salience_mapper']].run_with_metadata(indexed_inputs, model,
config['salience_mapper']].run_with_metadata(inputs_to_use, model,
dataset, model_outputs,
config)

Expand Down

0 comments on commit 8f3c26c

Please sign in to comment.