piscis.losses#
Attributes#
Functions#
| smoothf1_loss | Compute the SmoothF1 loss. |
| masked_l2_loss | Compute the L2 loss over masked pixels. |
| reduce_loss | Reduce the loss. |
| wrap_loss_fn | Wrap a loss function for vectorization and loss reduction. |
Module Contents#
- piscis.losses.smoothf1_loss(labels_pred: torch.Tensor, deltas_pred: torch.Tensor, deltas: torch.Tensor, p: torch.Tensor, max_distance: float = 3.0, kernel_size: Sequence[int] = (3, 3), temperature: float = 0.05, epsilon: float = 1e-07) -> torch.Tensor#
Compute the SmoothF1 loss.
- Parameters:
- labels_pred : torch.Tensor
Predicted labels.
- deltas_pred : torch.Tensor
Predicted displacement vectors.
- deltas : torch.Tensor
Ground truth displacement vectors.
- p : torch.Tensor
Number of ground truth spots in each image.
- max_distance : float, optional
Maximum distance for matching predicted and ground truth displacement vectors. Default is 3.
- kernel_size : Sequence[int], optional
Kernel size of sum or max pooling operations. Default is (3, 3).
- temperature : float, optional
Temperature parameter for softmax. Default is 0.05.
- epsilon : float, optional
Small constant for numerical stability. Default is 1e-7.
- Returns:
- smoothf1 : torch.Tensor
SmoothF1 loss.
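The roles of `p` and `epsilon` can be illustrated with a generic differentiable F1 loss. This is a minimal sketch, not Piscis's implementation: `soft_f1_loss_sketch` and its arguments `tp_soft` and `n_pred` are hypothetical names, and the matching step implied by `max_distance`, `kernel_size`, and `temperature` is assumed to have already produced a soft true-positive count.

```python
import torch


def soft_f1_loss_sketch(tp_soft, n_pred, n_true, epsilon=1e-7):
    """Generic soft F1 loss: 1 - 2 * TP / (n_pred + n_true).

    Hypothetical helper for illustration; not part of piscis.losses.
    """
    # epsilon keeps the ratio finite when an image has no predicted
    # and no ground truth spots (0 / 0 otherwise).
    f1 = 2.0 * tp_soft / (n_pred + n_true + epsilon)
    return 1.0 - f1


# 3 soft true positives, 4 predicted spots, 4 ground truth spots.
loss = soft_f1_loss_sketch(
    torch.tensor(3.0), torch.tensor(4.0), torch.tensor(4.0)
)

# Empty image: the loss stays finite because of epsilon.
empty_loss = soft_f1_loss_sketch(
    torch.tensor(0.0), torch.tensor(0.0), torch.tensor(0.0)
)
```

Here `n_true` plays the part that `p` plays in the signature above. Because every operation is differentiable, such a loss can be minimized directly with gradient descent, unlike an F1 score computed from hard counts.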
- piscis.losses.masked_l2_loss(input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor, epsilon: float = 1e-07) -> torch.Tensor#
Compute the L2 loss over masked pixels.
- Parameters:
- input : torch.Tensor
Predicted values.
- target : torch.Tensor
Ground truth values.
- mask : torch.Tensor
Mask tensor where each pixel is a boolean for whether it should be included in the loss computation.
- epsilon : float, optional
Small constant for numerical stability. Default is 1e-7.
- Returns:
- rmse : torch.Tensor
Masked root-mean-square error.
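A masked root-mean-square error of this kind can be sketched as follows. This is a minimal illustration under stated assumptions, not the library's code: `masked_rmse_sketch` is a hypothetical name, and where exactly `epsilon` enters the computation is an assumption.

```python
import torch


def masked_rmse_sketch(input, target, mask, epsilon=1e-7):
    """RMSE over pixels where mask is True (hypothetical helper)."""
    mask = mask.to(input.dtype)
    # Zero out the squared error of excluded pixels.
    sq_err = (input - target) ** 2 * mask
    # Assumption: epsilon guards against an all-False mask (division by
    # zero) and against sqrt(0), whose gradient is infinite.
    mse = sq_err.sum() / (mask.sum() + epsilon)
    return torch.sqrt(mse + epsilon)


pred = torch.tensor([[1.0, 2.0], [3.0, 10.0]])
true = torch.tensor([[1.0, 0.0], [0.0, 0.0]])
mask = torch.tensor([[True, True], [True, False]])

# Only the three masked-in pixels contribute: sqrt((0 + 4 + 9) / 3).
rmse = masked_rmse_sketch(pred, true, mask)

# With nothing masked in, the result is small but finite.
empty = masked_rmse_sketch(pred, true, torch.zeros_like(mask))
```

The large error at the bottom-right pixel has no effect on `rmse` because its mask entry is `False`.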
- piscis.losses.reduce_loss(loss: torch.Tensor, reduction: str | None = 'mean') -> torch.Tensor#
Reduce the loss.
- Parameters:
- loss : torch.Tensor
Loss tensor to be reduced.
- reduction : Optional[str], optional
Loss reduction method. Supported methods are ‘mean’ and ‘sum’. Default is ‘mean’.
- Returns:
- loss : torch.Tensor
Reduced loss.
- Raises:
- ValueError
If the reduction method is not supported.
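The documented behavior can be sketched as below. `reduce_loss_sketch` is a hypothetical stand-in rather than the library function, and treating `reduction=None` as "return the loss unreduced" is an assumption based only on the `Optional[str]` type.

```python
from typing import Optional

import torch


def reduce_loss_sketch(loss: torch.Tensor, reduction: Optional[str] = "mean"):
    """Reduce a loss tensor to a scalar (hypothetical helper)."""
    if reduction is None:
        # Assumption: None leaves the per-element losses untouched.
        return loss
    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    raise ValueError(f"Unsupported reduction method: {reduction!r}")


per_image = torch.tensor([1.0, 2.0, 3.0])
mean_loss = reduce_loss_sketch(per_image)         # average over images
sum_loss = reduce_loss_sketch(per_image, "sum")   # total over images
```

Passing any other string, for example `"max"`, raises `ValueError` in this sketch, mirroring the Raises entry above.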
- piscis.losses.wrap_loss_fn(loss_fn: Callable, axis: int = 0, reduction: str | None = 'mean') -> Callable#
Wrap a loss function for vectorization and loss reduction.
- Parameters:
- loss_fn : Callable
Loss function.
- axis : int, optional
Axis to vectorize over. Default is 0.
- reduction : Optional[str], optional
Loss reduction method. Default is ‘mean’.
- Returns:
- wrapped_loss_fn : Callable
Wrapped loss function.
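One way such a wrapper could work is to vectorize a per-image loss with `torch.vmap` and then reduce the per-image results. This is a sketch under assumptions: `wrap_loss_fn_sketch` and `per_image_mse` are hypothetical names, and whether Piscis uses `torch.vmap` internally is not stated in this reference.

```python
from typing import Callable, Optional

import torch


def wrap_loss_fn_sketch(
    loss_fn: Callable, axis: int = 0, reduction: Optional[str] = "mean"
) -> Callable:
    """Vectorize loss_fn over `axis`, then reduce (hypothetical helper)."""
    # Map loss_fn over the given axis of every positional argument.
    batched_fn = torch.vmap(loss_fn, in_dims=axis)

    def wrapped(*args):
        loss = batched_fn(*args)  # one loss value per image
        if reduction is None:
            return loss
        if reduction == "mean":
            return loss.mean()
        if reduction == "sum":
            return loss.sum()
        raise ValueError(f"Unsupported reduction method: {reduction!r}")

    return wrapped


def per_image_mse(pred, target):
    # Loss for a single image, written without a batch dimension.
    return ((pred - target) ** 2).mean()


pred = torch.tensor([[1.0, 2.0, 3.0], [0.0, 0.0, 0.0]])
target = torch.zeros(2, 3)

mean_mse = wrap_loss_fn_sketch(per_image_mse)(pred, target)
sum_mse = wrap_loss_fn_sketch(per_image_mse, reduction="sum")(pred, target)
```

Writing the loss for a single image and adding the batch handling in a wrapper keeps the per-image logic free of batch indexing.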
- piscis.losses.mean_smoothf1_loss#
- piscis.losses.mean_masked_l2_loss#