-
Notifications
You must be signed in to change notification settings - Fork 19.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Addition of Sparsemax activation #20558
Changes from 8 commits
68f6351
3d81e36
4d4eb0c
c5e1e1e
389c6f5
75df0b9
be86a40
9af84aa
fe9dc4c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,4 +18,5 @@ dist/** | |
examples/**/*.jpg | ||
.python-version | ||
.coverage | ||
*coverage.xml | ||
*coverage.xml | ||
.ruff_cache |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -617,3 +617,28 @@ def log_softmax(x, axis=-1): | |
axis: Integer, axis along which the softmax is applied. | ||
""" | ||
return ops.log_softmax(x, axis=axis) | ||
|
||
|
||
@keras_export(["keras.activations.sparsemax"]) | ||
def sparsemax(x, axis=-1): | ||
"""Sparsemax activation function. | ||
|
||
For each batch `i`, and class `j`, | ||
sparsemax activation function is defined as: | ||
|
||
`sparsemax(x)[i, j] = max(logits[i, j] - τ(logits[i, :]), 0).` | ||
|
||
Args: | ||
logits: Input tensor. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same here that's x |
||
axis: `int`, axis along which the sparsemax operation is applied. | ||
|
||
Returns: | ||
A tensor, output of sparsemax transformation. Has the same type and | ||
shape as `logits`. | ||
|
||
Reference: | ||
|
||
- [Martins et.al., 2016](https://arxiv.org/abs/1602.02068) | ||
""" | ||
x = backend.convert_to_tensor(x) | ||
return ops.sparsemax(x, axis) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -951,6 +951,48 @@ def log_softmax(x, axis=-1): | |
return backend.nn.log_softmax(x, axis=axis) | ||
|
||
|
||
class Sparsemax(Operation): | ||
def __init__(self, axis=-1): | ||
super().__init__() | ||
self.axis = axis | ||
|
||
def call(self, x): | ||
return backend.nn.sparsemax(x, axis=self.axis) | ||
|
||
def compute_output_spec(self, x): | ||
return KerasTensor(x.shape, dtype=x.dtype) | ||
|
||
|
||
@keras_export(["keras.ops.sparsemax", "keras.ops.nn.sparsemax"]) | ||
def sparsemax(x, axis=-1): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same comments for this docstring |
||
"""Sparsemax activation function. | ||
|
||
For each batch `i`, and class `j`, | ||
sparsemax activation function is defined as: | ||
|
||
`sparsemax(x)[i, j] = max(logits[i, j] - τ(logits[i, :]), 0).` | ||
|
||
Args: | ||
logits: Input tensor. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's x |
||
axis: `int`, axis along which the sparsemax operation is applied. | ||
|
||
Returns: | ||
A tensor, output of sparsemax transformation. Has the same type and | ||
shape as `logits`. | ||
|
||
Example: | ||
|
||
>>> x = np.array([-1., 0., 1.]) | ||
>>> x_sparsemax = keras.ops.sparsemax(x) | ||
>>> print(x_sparsemax) | ||
array([0., 0., 1.], shape=(3,), dtype=float64) | ||
|
||
""" | ||
if any_symbolic_tensors((x,)): | ||
return Sparsemax(axis).symbolic_call(x) | ||
return backend.nn.sparsemax(x, axis=axis) | ||
|
||
|
||
class MaxPool(Operation): | ||
def __init__( | ||
self, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
logits -> x