Addition of Sparsemax activation #20558

Merged · 9 commits · Nov 28, 2024

Changes from 8 commits
3 changes: 2 additions & 1 deletion .gitignore
@@ -18,4 +18,5 @@ dist/**
examples/**/*.jpg
.python-version
.coverage
*coverage.xml
.ruff_cache
1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/activations/__init__.py
@@ -33,6 +33,7 @@
from keras.src.activations.activations import softplus
from keras.src.activations.activations import softsign
from keras.src.activations.activations import sparse_plus
from keras.src.activations.activations import sparsemax
from keras.src.activations.activations import squareplus
from keras.src.activations.activations import tanh
from keras.src.activations.activations import tanh_shrink
1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/ops/__init__.py
@@ -100,6 +100,7 @@
from keras.src.ops.nn import softsign
from keras.src.ops.nn import sparse_categorical_crossentropy
from keras.src.ops.nn import sparse_plus
from keras.src.ops.nn import sparsemax
from keras.src.ops.nn import squareplus
from keras.src.ops.nn import tanh_shrink
from keras.src.ops.numpy import abs
1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/ops/nn/__init__.py
@@ -45,5 +45,6 @@
from keras.src.ops.nn import softsign
from keras.src.ops.nn import sparse_categorical_crossentropy
from keras.src.ops.nn import sparse_plus
from keras.src.ops.nn import sparsemax
from keras.src.ops.nn import squareplus
from keras.src.ops.nn import tanh_shrink
1 change: 1 addition & 0 deletions keras/api/activations/__init__.py
@@ -33,6 +33,7 @@
from keras.src.activations.activations import softplus
from keras.src.activations.activations import softsign
from keras.src.activations.activations import sparse_plus
from keras.src.activations.activations import sparsemax
from keras.src.activations.activations import squareplus
from keras.src.activations.activations import tanh
from keras.src.activations.activations import tanh_shrink
1 change: 1 addition & 0 deletions keras/api/ops/__init__.py
@@ -100,6 +100,7 @@
from keras.src.ops.nn import softsign
from keras.src.ops.nn import sparse_categorical_crossentropy
from keras.src.ops.nn import sparse_plus
from keras.src.ops.nn import sparsemax
from keras.src.ops.nn import squareplus
from keras.src.ops.nn import tanh_shrink
from keras.src.ops.numpy import abs
1 change: 1 addition & 0 deletions keras/api/ops/nn/__init__.py
@@ -45,5 +45,6 @@
from keras.src.ops.nn import softsign
from keras.src.ops.nn import sparse_categorical_crossentropy
from keras.src.ops.nn import sparse_plus
from keras.src.ops.nn import sparsemax
from keras.src.ops.nn import squareplus
from keras.src.ops.nn import tanh_shrink
2 changes: 2 additions & 0 deletions keras/src/activations/__init__.py
@@ -24,6 +24,7 @@
from keras.src.activations.activations import softplus
from keras.src.activations.activations import softsign
from keras.src.activations.activations import sparse_plus
from keras.src.activations.activations import sparsemax
from keras.src.activations.activations import squareplus
from keras.src.activations.activations import tanh
from keras.src.activations.activations import tanh_shrink
@@ -59,6 +60,7 @@
mish,
log_softmax,
log_sigmoid,
sparsemax,
}

ALL_OBJECTS_DICT = {fn.__name__: fn for fn in ALL_OBJECTS}
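Adding `sparsemax` to `ALL_OBJECTS` is what makes the activation resolvable by its string name. A minimal sketch of that lookup path (assuming a Keras build that includes this change):

import keras

fn = keras.activations.get("sparsemax")  # resolved via ALL_OBJECTS_DICT above
print(fn.__name__)  # sparsemax

# The same string now works anywhere an activation name is accepted:
layer = keras.layers.Dense(4, activation="sparsemax")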
25 changes: 25 additions & 0 deletions keras/src/activations/activations.py
@@ -617,3 +617,28 @@ def log_softmax(x, axis=-1):
axis: Integer, axis along which the softmax is applied.
"""
return ops.log_softmax(x, axis=axis)


@keras_export(["keras.activations.sparsemax"])
def sparsemax(x, axis=-1):
"""Sparsemax activation function.

For each batch `i` and class `j`, the
sparsemax activation function is defined as:

`sparsemax(x)[i, j] = max(logits[i, j] - τ(logits[i, :]), 0).`

[Collaborator comment: `logits` -> `x`]

Args:
logits: Input tensor.
[Collaborator comment: same here, that's `x`]

axis: `int`, axis along which the sparsemax operation is applied.

Returns:
A tensor, output of sparsemax transformation. Has the same type and
shape as `logits`.

Reference:

- [Martins et al., 2016](https://arxiv.org/abs/1602.02068)
"""
x = backend.convert_to_tensor(x)
return ops.sparsemax(x, axis)
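A quick usage sketch of the new activation (assuming a Keras build that includes this PR; any backend works since the call is routed through `keras.ops`). The point of sparsemax versus softmax is that it can output exact zeros:

import numpy as np
import keras

x = np.array([[1.0, 2.0, 5.0],
              [4.0, 1.0, 0.5]], dtype="float32")

# softmax assigns every class a strictly positive probability...
print(np.asarray(keras.activations.softmax(x)))
# ...while sparsemax collapses each of these rows onto its top class:
print(np.asarray(keras.activations.sparsemax(x)))  # [[0. 0. 1.], [1. 0. 0.]]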
49 changes: 49 additions & 0 deletions keras/src/activations/activations_test.py
@@ -896,6 +896,55 @@ def test_linear(self):
x_int32 = np.random.randint(-10, 10, (10, 5)).astype(np.int32)
self.assertAllClose(x_int32, activations.linear(x_int32))

def test_sparsemax(self):
# result check with 1d
x_1d = np.linspace(1, 12, num=12)
expected_result = np.zeros_like(x_1d)
expected_result[-1] = 1.0
self.assertAllClose(expected_result, activations.sparsemax(x_1d))

# result check with 2d
x_2d = np.linspace(1, 12, num=12).reshape(-1, 2)
expected_result = np.zeros_like(x_2d)
expected_result[:, -1] = 1.0
self.assertAllClose(expected_result, activations.sparsemax(x_2d))

# result check with 3d
x_3d = np.linspace(1, 12, num=12).reshape(-1, 1, 3)
expected_result = np.zeros_like(x_3d)
expected_result[:, :, -1] = 1.0
self.assertAllClose(expected_result, activations.sparsemax(x_3d))

# result check with axis=-2 with 2d input
x_2d = np.linspace(1, 12, num=12).reshape(-1, 2)
expected_result = np.zeros_like(x_2d)
expected_result[-1, :] = 1.0
self.assertAllClose(
expected_result, activations.sparsemax(x_2d, axis=-2)
)

# result check with axis=-2 with 3d input
x_3d = np.linspace(1, 12, num=12).reshape(-1, 1, 3)
expected_result = np.ones_like(x_3d)
self.assertAllClose(
expected_result, activations.sparsemax(x_3d, axis=-2)
)

# result check with axis=-3 with 3d input
x_3d = np.linspace(1, 12, num=12).reshape(-1, 1, 3)
expected_result = np.zeros_like(x_3d)
expected_result[-1, :, :] = 1.0
self.assertAllClose(
expected_result, activations.sparsemax(x_3d, axis=-3)
)

# result check with axis=-3 with 4d input
x_4d = np.linspace(1, 12, num=12).reshape(-1, 1, 1, 2)
expected_result = np.ones_like(x_4d)
self.assertAllClose(
expected_result, activations.sparsemax(x_4d, axis=-3)
)

def test_get_method(self):
obj = activations.get("relu")
self.assertEqual(obj, activations.relu)
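A note on the expected values in `test_sparsemax` above: consecutive `linspace(1, 12, num=12)` entries differ by exactly 1.0, which is large enough that the support condition keeps only the single largest logit, so every reduction over an axis of length > 1 yields a one-hot vector; the all-ones expectations are the degenerate case of reducing over a length-1 axis, where sparsemax trivially returns 1.0. A sketch of that support arithmetic:

import numpy as np

z = np.linspace(1, 12, num=12)[::-1]  # sorted descending: [12, 11, ..., 1]
cumsum = np.cumsum(z)
r = np.arange(1, 13)
print(z - (cumsum - 1) / r > 0)
# [ True False False ...] -> support of size 1, hence one-hot outputs.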
18 changes: 18 additions & 0 deletions keras/src/backend/jax/nn.py
@@ -142,6 +142,24 @@ def log_softmax(x, axis=-1):
return jnn.log_softmax(x, axis=axis)


def sparsemax(logits, axis=-1):
# Sort logits along the specified axis in descending order
logits = convert_to_tensor(logits)
logits_sorted = -1.0 * jnp.sort(logits * -1.0, axis=axis)
logits_cumsum = jnp.cumsum(logits_sorted, axis=axis) # find cumulative sum
r = jnp.arange(1, logits.shape[axis] + 1) # Determine the sparsity
r_shape = [1] * logits.ndim
r_shape[axis] = -1 # Broadcast to match the target axis
r = r.reshape(r_shape)
support = logits_sorted - (logits_cumsum - 1) / r > 0
# Find the threshold
k = jnp.sum(support, axis=axis, keepdims=True)
logits_cumsum_safe = jnp.where(support, logits_cumsum, 0.0)
tau = (jnp.sum(logits_cumsum_safe, axis=axis, keepdims=True) - 1) / k
output = jnp.maximum(logits - tau, 0.0)
return output


def _convert_to_spatial_operand(
x,
num_spatial_dims,
18 changes: 18 additions & 0 deletions keras/src/backend/numpy/nn.py
@@ -191,6 +191,24 @@ def log_softmax(x, axis=None):
return x - max_x - logsumexp


def sparsemax(logits, axis=-1):
# Sort logits along the specified axis in descending order
logits = convert_to_tensor(logits)
logits_sorted = -1.0 * np.sort(-1.0 * logits, axis=axis)
logits_cumsum = np.cumsum(logits_sorted, axis=axis)
r = np.arange(1, logits.shape[axis] + 1)
r_shape = [1] * logits.ndim
r_shape[axis] = -1 # Broadcast to match the target axis
r = r.reshape(r_shape)
support = logits_sorted - (logits_cumsum - 1) / r > 0
# Find the threshold
k = np.sum(support, axis=axis, keepdims=True)
logits_cumsum_safe = np.where(support, logits_cumsum, 0.0)
tau = (np.sum(logits_cumsum_safe, axis=axis, keepdims=True) - 1) / k
output = np.maximum(logits - tau, 0.0)
return output


def _convert_to_spatial_operand(
x,
num_spatial_dims,
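All four backend implementations share the same threshold computation; here is a hand trace in plain NumPy, mirroring the routine above on the vector used in the `nn_test.py` correctness test further down:

import numpy as np

logits = np.array([-0.5, 0.0, 1.0, 2.0, 3.0])

logits_sorted = -np.sort(-logits)         # [ 3.   2.   1.   0.  -0.5]
logits_cumsum = np.cumsum(logits_sorted)  # [ 3.   5.   6.   6.   5.5]
r = np.arange(1, logits.size + 1)         # [1 2 3 4 5]

support = logits_sorted - (logits_cumsum - 1) / r > 0
# [ True False False False False]: only the largest logit is in the support.

k = np.sum(support)                                   # 1
logits_cumsum_safe = np.where(support, logits_cumsum, 0.0)
tau = (np.sum(logits_cumsum_safe) - 1) / k            # (3 - 1) / 1 = 2.0

print(np.maximum(logits - tau, 0.0))  # [0. 0. 0. 0. 1.]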
18 changes: 18 additions & 0 deletions keras/src/backend/tensorflow/nn.py
@@ -151,6 +151,24 @@ def log_softmax(x, axis=-1):
return tf.nn.log_softmax(x, axis=axis)


def sparsemax(logits, axis=-1):
# Sort logits along the specified axis in descending order
logits = convert_to_tensor(logits)
logits_sorted = tf.sort(logits, direction="DESCENDING", axis=axis)
logits_cumsum = tf.cumsum(logits_sorted, axis=axis)
r = tf.range(1, tf.shape(logits)[axis] + 1, dtype=logits.dtype)
r_shape = [1] * len(logits.shape)
r_shape[axis] = -1 # Broadcast to match the target axis
r = tf.reshape(r, r_shape) # Reshape for broadcasting
support = logits_sorted - (logits_cumsum - 1) / r > 0
# Find the threshold
logits_cumsum_safe = tf.where(support, logits_cumsum, 0.0)
k = tf.reduce_sum(tf.cast(support, logits.dtype), axis=axis, keepdims=True)
tau = (tf.reduce_sum(logits_cumsum_safe, axis=axis, keepdims=True) - 1) / k
output = tf.maximum(logits - tau, 0.0)
return output


def _transpose_spatial_inputs(inputs):
num_spatial_dims = len(inputs.shape) - 2
# Tensorflow pooling does not support `channels_first` format, so
22 changes: 22 additions & 0 deletions keras/src/backend/torch/nn.py
@@ -174,6 +174,28 @@ def log_softmax(x, axis=-1):
return cast(output, dtype)


def sparsemax(logits, axis=-1):
# Sort logits along the specified axis in descending order
logits = convert_to_tensor(logits)
logits_sorted, _ = torch.sort(logits, dim=axis, descending=True)
logits_cumsum = torch.cumsum(logits_sorted, dim=axis)
r = torch.arange(
1, logits.size(axis) + 1, device=logits.device, dtype=logits.dtype
)
r_shape = [1] * logits.ndim
r_shape[axis] = -1 # Broadcast to match the target axis
r = r.view(r_shape)
support = logits_sorted - (logits_cumsum - 1) / r > 0
# Find the threshold
k = torch.sum(support, dim=axis, keepdim=True)
logits_cumsum_safe = torch.where(
support, logits_cumsum, torch.tensor(0.0, device=logits.device)
)
tau = (torch.sum(logits_cumsum_safe, dim=axis, keepdim=True) - 1) / k
output = torch.clamp(logits - tau, min=0.0)
return output


def _compute_padding_length(
input_length, kernel_length, stride, dilation_rate=1
):
42 changes: 42 additions & 0 deletions keras/src/ops/nn.py
@@ -951,6 +951,48 @@ def log_softmax(x, axis=-1):
return backend.nn.log_softmax(x, axis=axis)


class Sparsemax(Operation):
def __init__(self, axis=-1):
super().__init__()
self.axis = axis

def call(self, x):
return backend.nn.sparsemax(x, axis=self.axis)

def compute_output_spec(self, x):
return KerasTensor(x.shape, dtype=x.dtype)


@keras_export(["keras.ops.sparsemax", "keras.ops.nn.sparsemax"])
def sparsemax(x, axis=-1):
[Collaborator comment: same comments for this docstring]

"""Sparsemax activation function.

For each batch `i` and class `j`, the
sparsemax activation function is defined as:

`sparsemax(x)[i, j] = max(logits[i, j] - τ(logits[i, :]), 0).`

Args:
logits: Input tensor.
[Collaborator comment: it's `x`]

axis: `int`, axis along which the sparsemax operation is applied.

Returns:
A tensor, output of sparsemax transformation. Has the same type and
shape as `logits`.

Example:

>>> x = np.array([-1., 0., 1.])
>>> x_sparsemax = keras.ops.sparsemax(x)
>>> print(x_sparsemax)
array([0., 0., 1.], shape=(3,), dtype=float64)

"""
if any_symbolic_tensors((x,)):
return Sparsemax(axis).symbolic_call(x)
return backend.nn.sparsemax(x, axis=axis)


class MaxPool(Operation):
def __init__(
self,
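Because `Sparsemax` is implemented as a full `Operation` with `compute_output_spec`, the function also accepts symbolic `KerasTensor` inputs, so it can sit inside a functional model. A minimal sketch (the layer sizes are illustrative):

import keras
from keras import layers

inputs = keras.Input(shape=(16,))
logits = layers.Dense(4)(inputs)      # raw scores
probs = keras.ops.sparsemax(logits)   # symbolic call -> Sparsemax operation
model = keras.Model(inputs, probs)
model.summary()  # output shape (None, 4), same as the logits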
15 changes: 15 additions & 0 deletions keras/src/ops/nn_test.py
@@ -200,6 +200,10 @@ def test_log_softmax(self):
self.assertEqual(knn.log_softmax(x, axis=1).shape, (None, 2, 3))
self.assertEqual(knn.log_softmax(x, axis=-1).shape, (None, 2, 3))

def test_sparsemax(self):
x = KerasTensor([None, 2, 3])
self.assertEqual(knn.sparsemax(x).shape, (None, 2, 3))

def test_max_pool(self):
data_format = backend.config.image_data_format()
if data_format == "channels_last":
@@ -861,6 +865,10 @@ def test_log_softmax(self):
self.assertEqual(knn.log_softmax(x, axis=1).shape, (1, 2, 3))
self.assertEqual(knn.log_softmax(x, axis=-1).shape, (1, 2, 3))

def test_sparsemax(self):
x = KerasTensor([1, 2, 3])
self.assertEqual(knn.sparsemax(x).shape, (1, 2, 3))

def test_max_pool(self):
data_format = backend.config.image_data_format()
if data_format == "channels_last":
@@ -1487,6 +1495,13 @@ def test_log_softmax_correctness_with_axis_tuple(self):
)
self.assertAllClose(normalized_sum_by_axis, 1.0)

def test_sparsemax(self):
x = np.array([-0.5, 0, 1, 2, 3], dtype=np.float32)
self.assertAllClose(
knn.sparsemax(x),
[0.0, 0.0, 0.0, 0.0, 1.0],
)

def test_max_pool(self):
data_format = backend.config.image_data_format()
# Test 1D max pooling.