Module riid.losses.sparsemax
This module contains sparsemax-related functions.
Expand source code Browse git
# Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS).
# Under the terms of Contract DE-NA0003525 with NTESS,
# the U.S. Government retains certain rights in this software.
# This code is based on Tensorflow-Addons. THE ORIGINAL CODE HAS BEEN MODIFIED.
# https://www.tensorflow.org/addons/
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module contains sparsemax-related functions."""
from typing import Optional
import tensorflow as tf
from typeguard import typechecked
def sparsemax(logits, axis: int = -1) -> tf.Tensor:
    r"""Sparsemax activation function.

    For each batch \( i \), and class \( j \),
    compute sparsemax activation function:

    $$
    \mathrm{sparsemax}(x)[i, j] = \max(\mathrm{logits}[i, j] - \tau(\mathrm{logits}[i, :]), 0).
    $$

    See
    [From Softmax to Sparsemax: A Sparse Model of Attention and Multi-Label Classification
    ](https://arxiv.org/abs/1602.02068).

    Usage:

    >>> x = tf.constant([[-1.0, 0.0, 1.0], [-5.0, 1.0, 2.0]])
    >>> sparsemax(x)
    <tf.Tensor: shape=(2, 3), dtype=float32, numpy=
    array([[0., 0., 1.],
           [0., 0., 1.]], dtype=float32)>

    Args:
        logits: `Tensor`
        axis: `int`, axis along which the sparsemax operation is applied

    Returns:
        `Tensor`, output of sparsemax transformation (has the same type and shape as `logits`)

    Raises:
        `ValueError` when `axis` is outside `[-rank, rank)` for a statically
        known rank, or when the static rank of `logits` is unknown and
        `axis` is not `-1`.
    """
    logits = tf.convert_to_tensor(logits, name="logits")

    # Keep the static shape so it can be re-attached to the result.
    shape = logits.get_shape()
    rank = shape.rank

    if rank is None:
        # Normalising any axis other than -1 needs the rank as a Python int;
        # fail clearly instead of with a TypeError on `None - 1`.
        if axis != -1:
            raise ValueError(
                "sparsemax requires logits with a statically known rank "
                f"when axis is not -1, got axis={axis}"
            )
    elif not -rank <= axis < rank:
        # `axis % rank` would otherwise wrap an invalid axis onto a valid one.
        raise ValueError(
            f"axis {axis} is out of range for logits of rank {rank}"
        )

    if axis == -1 or axis == rank - 1:
        output = _compute_2d_sparsemax(logits)
        output.set_shape(shape)
        return output

    # The axis is not the last dimension: transpose so that sparsemax can
    # still be computed along the last dimension, then transpose back.
    last_axis = tf.math.subtract(tf.rank(logits), 1)
    axis_norm = axis % rank
    swapped = _swap_axis(logits, axis_norm, last_axis)

    # Do the actual sparsemax on the last dimension.
    output = _compute_2d_sparsemax(swapped)
    output = _swap_axis(output, axis_norm, last_axis)

    # Make shape inference work since transpose may erase the static shape.
    output.set_shape(shape)
    return output
def _swap_axis(logits, dim_index, last_index, **kwargs):
    """Transpose `logits` so that axes `dim_index` and `last_index` trade places.

    The permutation is built as
    `[0 .. dim_index-1, last_index, dim_index+1 .. last_index-1, dim_index]`.

    NOTE(review): this construction presumes `dim_index < last_index`; the
    caller in this module only uses it with `last_index = rank - 1` after
    ruling out the last axis.

    Args:
        logits: `Tensor` to transpose.
        dim_index: index of the axis to move to position `last_index`.
        last_index: index of the axis to move to position `dim_index`.
        **kwargs: forwarded unchanged to `tf.transpose`.

    Returns:
        The transposed `Tensor`.
    """
    leading = tf.range(dim_index)
    between = tf.range(dim_index + 1, last_index)
    permutation = tf.concat([leading, [last_index], between, [dim_index]], 0)
    return tf.transpose(logits, permutation, **kwargs)
def _compute_2d_sparsemax(logits):
    """Perform the sparsemax operation when axis=-1.

    The input is flattened to `[obs, dims]` (all leading dimensions folded
    into `obs`), sparsemax is evaluated per row, and the result is reshaped
    back to the dynamic shape of `logits`.

    Args:
        logits: `Tensor` whose last dimension is the sparsemax axis.

    Returns:
        `Tensor` with the same dynamic shape as `logits`. Rows for which the
        support size is zero or whose cumulative sum is NaN are filled with
        NaN.
    """
    dtype = logits.dtype
    full_shape = tf.shape(logits)
    num_rows = tf.math.reduce_prod(full_shape[:-1])
    num_cols = full_shape[-1]

    # In the paper, the logits are called z.
    # The mean(logits) could be subtracted from the logits to make the
    # algorithm more numerically stable; the instability comes mostly from
    # the cumulative sum, and subtracting the mean keeps it close to zero.
    # In practice the instability is very minor, and subtracting the mean
    # causes extra issues with inf and nan input, so it is not done.

    # Reshaping to [obs, dims] is almost free and means the remaining code
    # does not need to worry about the rank.
    z = tf.reshape(logits, [num_rows, num_cols])

    # Sort every row in descending order (top_k with k equal to the row length).
    z_desc, _ = tf.nn.top_k(z, k=num_cols)

    # Calculate k(z).
    running_sum = tf.math.cumsum(z_desc, axis=-1)
    positions = tf.range(1, tf.cast(num_cols, dtype) + 1, dtype=dtype)
    in_support = 1 + positions * z_desc > running_sum
    # The check vector is always [1, 1, ..., 1, 0, 0, ..., 0], so the
    # (index + 1) of the last `1` equals the number of ones.
    support_size = tf.math.reduce_sum(tf.cast(in_support, tf.int32), axis=-1)

    # Calculate tau(z).
    # If there are inf values or all values are -inf, the support size will
    # be zero; this is mathematically invalid and would make gather_nd fail.
    # Clamp the gather index to a valid column here; such rows are replaced
    # by NaN further down, which matches the behavior of softmax.
    last_support_col = tf.reshape(tf.math.maximum(support_size, 1), [-1]) - 1
    gather_idx = tf.stack([tf.range(0, num_rows), last_support_col], axis=1)
    support_sum = tf.gather_nd(running_sum, gather_idx)
    # The divisor is the unclamped support size on purpose: rows where it is
    # zero are overwritten with NaN below.
    tau = (support_sum - 1) / tf.cast(support_size, dtype)

    # Calculate p.
    projected = tf.math.maximum(tf.cast(0, dtype), z - tf.expand_dims(tau, -1))

    # A zero support size or a NaN cumulative sum marks the row as invalid.
    row_invalid = tf.math.logical_or(
        tf.math.equal(support_size, 0),
        tf.math.is_nan(running_sum[:, -1]),
    )
    nan_block = tf.fill([num_rows, num_cols], tf.cast(float("nan"), dtype))
    result = tf.where(tf.expand_dims(row_invalid, axis=-1), nan_block, projected)

    # Reshape back to the original size.
    return tf.reshape(result, full_shape)
def sparsemax_loss(logits, sparsemax, labels, name: Optional[str] = None) -> tf.Tensor:
    r"""Sparsemax loss function ([1]).

    Computes the generalized multi-label classification loss for the sparsemax
    function. The implementation is a reformulation of the original loss
    function such that it uses the sparsemax probability output instead of the
    internal \( \tau \) variable. However, the output is identical to the original
    loss function.

    [1]: https://arxiv.org/abs/1602.02068

    Args:
        logits: `Tensor`. Must be one of the following types: `float32`,
            `float64`.
        sparsemax: `Tensor`. Must have the same type as `logits`.
        labels: `Tensor`. Must have the same type as `logits`.
        name: name for the operation (optional). Applied to the final
            reduction op that produces the returned tensor.

    Returns:
        A `Tensor`. Has the same type as `logits`. The per-element terms are
        summed over axis 1.
    """
    logits = tf.convert_to_tensor(logits, name="logits")
    sparsemax = tf.convert_to_tensor(sparsemax, name="sparsemax")
    labels = tf.convert_to_tensor(labels, name="labels")

    # In the paper, the logits are called z.
    # A constant could be subtracted from the logits to make the algorithm
    # more numerically stable in theory. However, there is no major source
    # of numerical instability in this algorithm.
    z = logits

    # Sum over the support.
    # Use a conditional `where` instead of a multiplication to support
    # z = -inf. If z = -inf and there is no support (sparsemax = 0), a
    # multiplication would give 0 * -inf = nan, which is not correct here.
    sum_s = tf.where(
        tf.math.logical_or(sparsemax > 0, tf.math.is_nan(sparsemax)),
        sparsemax * (z - 0.5 * sparsemax),
        tf.zeros_like(sparsemax),
    )

    # - z_k + ||q||^2
    q_part = labels * (0.5 * labels - z)
    # Fix the case where labels = 0 and z = -inf, where q_part would
    # otherwise be 0 * -inf = nan. Since the labels are 0, no cost for
    # z = -inf should be considered.
    # The condition below also covers z = inf. In that case, however, the
    # sparsemax will be nan, which means sum_s will also be nan, so this
    # case does not need additional special treatment.
    q_part_safe = tf.where(
        tf.math.logical_and(tf.math.equal(labels, 0), tf.math.is_inf(z)),
        tf.zeros_like(z),
        q_part,
    )

    # Forward `name` so the documented argument is honored; with the default
    # `None` the reduction keeps its automatically generated name.
    return tf.math.reduce_sum(sum_s + q_part_safe, axis=1, name=name)
@tf.function
@tf.keras.utils.register_keras_serializable(package="Addons")
def sparsemax_loss_from_logits(
    y_true, logits_pred
) -> tf.Tensor:
    """Compute the sparsemax loss directly from predicted logits.

    Applies `sparsemax` to `logits_pred` (default axis) and passes the
    logits, the resulting activations and `y_true` to `sparsemax_loss`.

    Args:
        y_true: target values, used as the `labels` of `sparsemax_loss`.
        logits_pred: predicted logits.

    Returns:
        The `Tensor` returned by `sparsemax_loss`.
    """
    return sparsemax_loss(logits_pred, sparsemax(logits_pred), y_true)
@tf.keras.utils.register_keras_serializable(package="Addons")
class SparsemaxLoss(tf.keras.losses.Loss):
    """Sparsemax loss function.

    Computes the generalized multi-label classification loss for the sparsemax
    function.

    Because the sparsemax loss function needs both the probability output and
    the logits to compute the loss value, `from_logits` must be `True`.

    Because it computes the generalized multi-label loss, the shape of both
    `y_pred` and `y_true` must be `[batch_size, num_classes]`.

    Args:
        from_logits: Whether `y_pred` is expected to be a logits tensor. Default
            is `True`, meaning `y_pred` is the logits.
        reduction: (Optional) Type of `tf.keras.losses.Reduction` to apply to
            loss. Default value is `SUM_OVER_BATCH_SIZE`.
        name: Optional name for the op
    """

    @typechecked
    def __init__(
        self,
        from_logits: bool = True,
        reduction: str = tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE,
        name: str = "sparsemax_loss",
    ):
        # Only the logits formulation is implemented; reject anything else
        # before the base class is initialised.
        if from_logits is not True:
            raise ValueError("from_logits must be True")

        super().__init__(name=name, reduction=reduction)
        self.from_logits = from_logits

    def call(self, y_true, y_pred):
        """Return the sparsemax loss, treating `y_pred` as logits."""
        return sparsemax_loss_from_logits(y_true, y_pred)

    def get_config(self):
        """Return the base loss config extended with `from_logits`."""
        merged = dict(super().get_config())
        merged.update({"from_logits": self.from_logits})
        return merged
Functions
def sparsemax(logits, axis: int = -1) ‑> tensorflow.python.framework.tensor.Tensor-
Sparsemax activation function.
For each batch \( i \), and class \( j \), compute sparsemax activation function:
$$
\mathrm{sparsemax}(x)[i, j] = \max(\mathrm{logits}[i, j] - \tau(\mathrm{logits}[i, :]), 0).
$$
See From Softmax to Sparsemax: A Sparse Model of Attention and Multi-Label Classification .
Usage:
>>> x = tf.constant([[-1.0, 0.0, 1.0], [-5.0, 1.0, 2.0]])
>>> tfa.activations.sparsemax(x)
<tf.Tensor: shape=(2, 3), dtype=float32, numpy=
array([[0., 0., 1.],
       [0., 0., 1.]], dtype=float32)>
Args
logits: Tensor
axis: int, axis along which the sparsemax operation is applied
Returns
Tensor, output of sparsemax transformation (has the same type and shape as logits)
Raises
ValueError when dim(logits) == 1
Expand source code Browse git
def sparsemax(logits, axis: int = -1) -> tf.Tensor: r"""Sparsemax activation function. For each batch \( i \), and class \( j \), compute sparsemax activation function: $$ \mathrm{sparsemax}(x)[i, j] = \max(\mathrm{logits}[i, j] - \tau(\mathrm{logits}[i, :]), 0). $$ See [From Softmax to Sparsemax: A Sparse Model of Attention and Multi-Label Classification ](https://arxiv.org/abs/1602.02068). Usage: >>> x = tf.constant([[-1.0, 0.0, 1.0], [-5.0, 1.0, 2.0]]) >>> tfa.activations.sparsemax(x) <tf.Tensor: shape=(2, 3), dtype=float32, numpy= array([[0., 0., 1.], [0., 0., 1.]], dtype=float32)> Args: logits: `Tensor` axis: `int`, axis along which the sparsemax operation is applied Returns: `Tensor`, output of sparsemax transformation (has the same type and shape as `logits`) Raises: `ValueError` when `dim(logits) == 1` """ logits = tf.convert_to_tensor(logits, name="logits") # We need its original shape for shape inference. shape = logits.get_shape() rank = shape.rank is_last_axis = (axis == -1) or (axis == rank - 1) if is_last_axis: output = _compute_2d_sparsemax(logits) output.set_shape(shape) return output # If dim is not the last dimension, we have to do a transpose so that we can # still perform softmax on its last dimension. # Swap logits' dimension of dim and its last dimension. rank_op = tf.rank(logits) axis_norm = axis % rank logits = _swap_axis(logits, axis_norm, tf.math.subtract(rank_op, 1)) # Do the actual softmax on its last dimension. output = _compute_2d_sparsemax(logits) output = _swap_axis(output, axis_norm, tf.math.subtract(rank_op, 1)) # Make shape inference work since transpose may erase its static shape. output.set_shape(shape) return output def sparsemax_loss(logits, sparsemax, labels, name: Optional[str] = None) ‑> tensorflow.python.framework.tensor.Tensor-
Sparsemax loss function (1).
Computes the generalized multi-label classification loss for the sparsemax function. The implementation is a reformulation of the original loss function such that it uses the sparsemax probability output instead of the internal \( \tau \) variable. However, the output is identical to the original loss function.
Args
logits: Tensor. Must be one of the following types: float32, float64.
sparsemax: Tensor. Must have the same type as logits.
labels: Tensor. Must have the same type as logits.
name: name for the operation (optional).
Returns
A Tensor. Has the same type as logits.
Expand source code Browse git
def sparsemax_loss(logits, sparsemax, labels, name: Optional[str] = None) -> tf.Tensor: r"""Sparsemax loss function ([1]). Computes the generalized multi-label classification loss for the sparsemax function. The implementation is a reformulation of the original loss function such that it uses the sparsemax probability output instead of the internal \( \tau \) variable. However, the output is identical to the original loss function. [1]: https://arxiv.org/abs/1602.02068 Args: logits: `Tensor`. Must be one of the following types: `float32`, `float64`. sparsemax: `Tensor`. Must have the same type as `logits`. labels: `Tensor`. Must have the same type as `logits`. name: name for the operation (optional). Returns: A `Tensor`. Has the same type as `logits`. """ logits = tf.convert_to_tensor(logits, name="logits") sparsemax = tf.convert_to_tensor(sparsemax, name="sparsemax") labels = tf.convert_to_tensor(labels, name="labels") # In the paper, they call the logits z. # A constant can be substracted from logits to make the algorithm # more numerically stable in theory. However, there are really no major # source numerical instability in this algorithm. z = logits # sum over support # Use a conditional where instead of a multiplication to support z = -inf. # If z = -inf, and there is no support (sparsemax = 0), a multiplication # would cause 0 * -inf = nan, which is not correct in this case. sum_s = tf.where( tf.math.logical_or(sparsemax > 0, tf.math.is_nan(sparsemax)), sparsemax * (z - 0.5 * sparsemax), tf.zeros_like(sparsemax), ) # - z_k + ||q||^2 q_part = labels * (0.5 * labels - z) # Fix the case where labels = 0 and z = -inf, where q_part would # otherwise be 0 * -inf = nan. But since the lables = 0, no cost for # z = -inf should be consideredself. # The code below also coveres the case where z = inf. Howeverm in this # caose the sparsemax will be nan, which means the sum_s will also be nan, # therefor this case doesn't need addtional special treatment. 
q_part_safe = tf.where( tf.math.logical_and(tf.math.equal(labels, 0), tf.math.is_inf(z)), tf.zeros_like(z), q_part, ) return tf.math.reduce_sum(sum_s + q_part_safe, axis=1) def sparsemax_loss_from_logits(y_true, logits_pred) ‑> tensorflow.python.framework.tensor.Tensor-
Expand source code Browse git
@tf.function @tf.keras.utils.register_keras_serializable(package="Addons") def sparsemax_loss_from_logits( y_true, logits_pred ) -> tf.Tensor: y_pred = sparsemax(logits_pred) loss = sparsemax_loss(logits_pred, y_pred, y_true) return loss
Classes
class SparsemaxLoss (from_logits: bool = True, reduction: str = 'sum_over_batch_size', name: str = 'sparsemax_loss')-
Sparsemax loss function.
Computes the generalized multi-label classification loss for the sparsemax function.
Because the sparsemax loss function needs both the probability output and the logits to compute the loss value, from_logits must be True.
Because it computes the generalized multi-label loss, the shape of both y_pred and y_true must be [batch_size, num_classes].
Args
from_logits: Whether y_pred is expected to be a logits tensor. Default is True, meaning y_pred is the logits.
reduction: (Optional) Type of tf.keras.losses.Reduction to apply to loss. Default value is SUM_OVER_BATCH_SIZE.
name: Optional name for the op
Expand source code Browse git
@tf.keras.utils.register_keras_serializable(package="Addons") class SparsemaxLoss(tf.keras.losses.Loss): """Sparsemax loss function. Computes the generalized multi-label classification loss for the sparsemax function. Because the sparsemax loss function needs both the probability output and the logits to compute the loss value, `from_logits` must be `True`. Because it computes the generalized multi-label loss, the shape of both `y_pred` and `y_true` must be `[batch_size, num_classes]`. Args: from_logits: Whether `y_pred` is expected to be a logits tensor. Default is `True`, meaning `y_pred` is the logits. reduction: (Optional) Type of `tf.keras.losses.Reduction` to apply to loss. Default value is `SUM_OVER_BATCH_SIZE`. name: Optional name for the op """ @typechecked def __init__( self, from_logits: bool = True, reduction: str = tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE, name: str = "sparsemax_loss", ): if from_logits is not True: raise ValueError("from_logits must be True") super().__init__(name=name, reduction=reduction) self.from_logits = from_logits def call(self, y_true, y_pred): return sparsemax_loss_from_logits(y_true, y_pred) def get_config(self): config = { "from_logits": self.from_logits, } base_config = super().get_config() return {**base_config, **config}Ancestors
- keras.src.losses.loss.Loss
- keras.src.saving.keras_saveable.KerasSaveable
Methods
def call(self, y_true, y_pred)-
Expand source code Browse git
def call(self, y_true, y_pred): return sparsemax_loss_from_logits(y_true, y_pred) def get_config(self)-
Expand source code Browse git
def get_config(self): config = { "from_logits": self.from_logits, } base_config = super().get_config() return {**base_config, **config}