-
Notifications
You must be signed in to change notification settings - Fork 356
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
LIT: Refactor instrumented Keras LMs to support TF and Torch
Summary of changes: * Wherever possible, use `keras.ops` for Tensor operations. These functions are designed to create and operate over Tensors from any of the Keras 3 backends. * Identified a common API surface for salience predictions that supports all existing (and likely any future) Keras 3 backends. Functions implementing this API shoudl be named `_pred_{framework}` and return a tuple of the GradNorm and GradDotInput salience scores. * Refactors `KerasSalienceModel._pred()` to perform common operations for data before calling out to backend-specific prediction functions. * Extracts TensorFlow code in `KerasSalienceModel` to a new `_pred_tf()` function. * Implements a `KerasSalienceModel._pred_torch()` function based on the HF implementation in lit_nlp/examples/models/pretrained_lms.py * Provides a stub for `KerasSalienceModel._pred_jax()` with a detailed comment outlining the JAX idiosyncrasies that we need to adapt to in order to support this backend. PiperOrigin-RevId: 622966761
- Loading branch information
1 parent
82abec6
commit 5ee7064
Showing
9 changed files
with
339 additions
and
197 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.