piscis.training

Functions

loss_fn(...) → Tuple[torch.Tensor, Dict[str, float]]

Computes the loss and metrics for a given batch.

train_epoch(...) → Dict[str, float]

Train the model for a single epoch.

val_epoch(...) → Dict[str, float]

Validate the model for a single epoch.

train_model(..., random_seed, batch_size, num_workers, ...)

Train a SpotsModel.

Module Contents

piscis.training.loss_fn(labels_pred: torch.Tensor, deltas_pred: torch.Tensor, labels: torch.Tensor, deltas: torch.Tensor, p: torch.Tensor, l2_loss_weight: float, max_distance: float, kernel_size: Sequence[int], temperature: float, epsilon: float) → Tuple[torch.Tensor, Dict[str, float]]

Computes the loss and metrics for a given batch.

Parameters:
labels_pred : torch.Tensor

Predicted labels.

deltas_pred : torch.Tensor

Predicted displacement vectors.

labels : torch.Tensor

Ground truth labels.

deltas : torch.Tensor

Ground truth displacement vectors.

p : torch.Tensor

Number of ground truth spots in each image.

l2_loss_weight : float

Weight for the masked L2 loss term in the overall loss function.

max_distance : float

Maximum distance for matching predicted and ground truth displacement vectors.

kernel_size : Sequence[int], optional

Kernel size of the sum or max pooling operations. Default is (3, 3).

temperature : float

Temperature parameter for softmax.

epsilon : float

Small constant for numerical stability.

Returns:
loss : torch.Tensor

Overall loss value.

metrics : Dict[str, float]

Dictionary containing the values of individual loss terms and the overall loss.
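The l2_loss_weight parameter weights a masked L2 loss on the displacement vectors. The sketch below illustrates such a term in plain Python rather than PyTorch. It assumes that the mask is the set of pixels whose ground truth label is 1 and that the overall loss adds this term, scaled by l2_loss_weight, to some primary term. Both helpers, masked_l2_loss and combine_losses, are hypothetical and are not part of piscis.training.

```python
from typing import Sequence, Tuple

# A displacement vector (dy, dx) for one pixel.
Vector = Tuple[float, float]


def masked_l2_loss(
    deltas_pred: Sequence[Vector],
    deltas: Sequence[Vector],
    labels: Sequence[int],
    epsilon: float = 1e-7,
) -> float:
    """Mean squared displacement error over pixels whose ground truth label is 1.

    Hypothetical helper: pixels are flattened into parallel sequences, and
    unlabeled pixels (label 0) are masked out of the average.
    """
    total = 0.0
    count = 0
    for (pred_dy, pred_dx), (true_dy, true_dx), label in zip(deltas_pred, deltas, labels):
        if label == 1:
            total += (pred_dy - true_dy) ** 2 + (pred_dx - true_dx) ** 2
            count += 1
    # epsilon keeps the division defined when no pixel is labeled.
    return total / (count + epsilon)


def combine_losses(primary_loss: float, l2_loss: float, l2_loss_weight: float) -> float:
    """Assumed weighted-sum form of the overall loss."""
    return primary_loss + l2_loss_weight * l2_loss


pred = [(0.5, 0.0), (1.0, 1.0), (2.0, 2.0)]
true = [(0.0, 0.0), (1.0, 0.0), (9.0, 9.0)]
labels = [1, 1, 0]  # the third pixel is masked out despite its large error
l2 = masked_l2_loss(pred, true, labels)
print(round(l2, 4))  # 0.625
print(round(combine_losses(1.0, l2, l2_loss_weight=0.1), 4))  # 1.0625
```

Masking restricts the regression penalty to pixels near ground truth spots, so background pixels with arbitrary predicted displacements do not affect this term.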

piscis.training.train_epoch(model: piscis.models.spots.SpotsModel, dataloader: tqdm.tqdm, optimizer: torch.optim.Optimizer, l2_loss_weight: float, dilation_iterations: int, max_distance: float, temperature: float, epsilon: float, device: str | None) → Dict[str, float]

Train the model for a single epoch.

Parameters:
model : SpotsModel

Model to be trained.

dataloader : tqdm

DataLoader for the training data, wrapped in a tqdm progress bar.

optimizer : torch.optim.Optimizer

Optimizer for updating model parameters.

l2_loss_weight : float

Weight for the masked L2 loss term in the overall loss function.

dilation_iterations : int

Number of iterations to dilate ground truth labels.

max_distance : float

Maximum distance for matching predicted and ground truth displacement vectors.

temperature : float

Temperature parameter for softmax.

epsilon : float

Small constant for numerical stability.

device : Optional[str]

Device for training.

Returns:
train_metrics : Dict[str, float]

Dictionary containing average training metrics for the epoch.
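The dilation_iterations parameter sets how many times the ground truth labels are dilated. The plain-Python sketch below applies repeated binary dilation to a small 2D label grid. It assumes a 3x3 square neighborhood (8-connectivity), because the neighborhood Piscis actually uses is not stated here. The helper dilate_labels is hypothetical.

```python
from typing import List

Grid = List[List[int]]


def dilate_labels(labels: Grid, iterations: int) -> Grid:
    """Binary dilation of a 2D label grid, repeated `iterations` times.

    Assumes a 3x3 square neighborhood; the input grid is not modified.
    """
    height, width = len(labels), len(labels[0])
    current = [row[:] for row in labels]
    for _ in range(iterations):
        dilated = [row[:] for row in current]
        for y in range(height):
            for x in range(width):
                if not current[y][x]:
                    continue
                # Mark every in-bounds pixel in the 3x3 window around a labeled pixel.
                for dy in (-1, 0, 1):
                    for dx in (-1, 0, 1):
                        ny, nx = y + dy, x + dx
                        if 0 <= ny < height and 0 <= nx < width:
                            dilated[ny][nx] = 1
        current = dilated
    return current


spot = [[0] * 5 for _ in range(5)]
spot[2][2] = 1  # a single ground truth spot in the center
for n in range(3):
    print(n, sum(map(sum, dilate_labels(spot, n))))  # prints "0 1", then "1 9", then "2 25"
```

Under this assumption, each iteration grows every labeled region by one pixel in all directions. A single-pixel spot therefore covers a 3x3 patch after one iteration and a 5x5 patch after two, which increases the number of positive pixels available to the loss.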

piscis.training.val_epoch(model: piscis.models.spots.SpotsModel, dataloader: torch.utils.data.DataLoader, l2_loss_weight: float, dilation_iterations: int, max_distance: float, temperature: float, epsilon: float, device: str | None) → Dict[str, float]

Validate the model for a single epoch.

Parameters:
model : SpotsModel

Model to be validated.

dataloader : torch.utils.data.DataLoader

DataLoader for the validation data.

l2_loss_weight : float

Weight for the masked L2 loss term in the overall loss function.

dilation_iterations : int

Number of iterations to dilate ground truth labels.

max_distance : float

Maximum distance for matching predicted and ground truth displacement vectors.

temperature : float

Temperature parameter for softmax.

epsilon : float

Small constant for numerical stability.

device : Optional[str]

Device for validation.

Returns:
val_metrics : Dict[str, float]

Dictionary containing average validation metrics for the epoch.
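The temperature parameter is described as a softmax temperature. The plain-Python example below shows its general effect. The logits are divided by the temperature before normalization, so small values such as the default 0.05 produce a sharply peaked output. This is the textbook formulation and not necessarily how Piscis applies the temperature internally. The helper softmax_with_temperature is hypothetical.

```python
import math
from typing import List, Sequence


def softmax_with_temperature(logits: Sequence[float], temperature: float) -> List[float]:
    """Return softmax(logits / temperature) as a list of probabilities."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [z / temperature for z in logits]
    # Subtracting the maximum before exponentiating avoids overflow for small temperatures.
    shift = max(scaled)
    exps = [math.exp(s - shift) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [1.0, 0.9, 0.1]
print([round(p, 2) for p in softmax_with_temperature(logits, 1.0)])   # [0.43, 0.39, 0.18]
print([round(p, 2) for p in softmax_with_temperature(logits, 0.05)])  # [0.88, 0.12, 0.0]
```

Lowering the temperature from 1.0 to 0.05 moves almost all of the probability mass onto the largest logit. This is why a small temperature makes the softmax behave like a differentiable approximation of argmax.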

piscis.training.train_model(model_name: str, dataset_path: str | List[str] | Dict[str, float], initial_model_name: str | None = None, adjustment: str | None = 'standardize', input_size: Tuple[int, int] = (256, 256), random_seed: int = 0, batch_size: int = 4, num_workers: int = 0, learning_rate: float = 0.1, weight_decay: float = 1e-05, epochs: int = 500, warmup_fraction: float = 0.04, decay_fraction: float = 0.4, decay_transitions: int = 10, decay_factor: float = 0.5, l2_loss_weight: float = 0.1, dilation_iterations: int = 1, max_distance: float = 3.0, temperature: float = 0.05, epsilon: float = 1e-07, checkpoint_every: int = 10, device: str | None = 'cuda') → None

Train a SpotsModel.

Parameters:
model_name : str

Model name.

dataset_path : Union[str, List[str], Dict[str, float]]

Path to a single dataset, path to a directory containing multiple datasets, a list of dataset paths, or a dictionary mapping dataset paths to their sampling weights. If a directory or a list is provided, all datasets it contains are loaded and concatenated with equal weights. If a dictionary is provided, the datasets are loaded and concatenated with the specified weights.

initial_model_name : Optional[str], optional

Name of an existing model whose weights are used for initialization. Default is None.

adjustment : Optional[str], optional

Adjustment type applied to images. Supported types are ‘normalize’ and ‘standardize’. Default is ‘standardize’.

input_size : Tuple[int, int], optional

Input size used for training. Default is (256, 256).

random_seed : int, optional

Random seed used for initialization and training. Default is 0.

batch_size : int, optional

Batch size for training. Default is 4.

num_workers : int, optional

Number of workers for data loading. Default is 0.

learning_rate : float, optional

Learning rate for the optimizer. Default is 0.1.

weight_decay : float, optional

Strength of the weight decay regularization. Default is 1e-5.

epochs : int, optional

Number of epochs to train the model for. Default is 500.

warmup_fraction : float, optional

Fraction of epochs for learning rate warmup. Default is 0.04.

decay_fraction : float, optional

Fraction of epochs for learning rate decay. Default is 0.4.

decay_transitions : int, optional

Number of times to decay the learning rate. Default is 10.

decay_factor : float, optional

Multiplicative factor of each learning rate decay transition. Default is 0.5.

l2_loss_weight : float, optional

Weight for the masked L2 loss term in the overall loss function. Default is 0.1.

dilation_iterations : int, optional

Number of iterations to dilate ground truth labels, which reduces class imbalance and misclassifications caused by minor offsets. Default is 1.

max_distance : float, optional

Maximum distance for matching predicted and ground truth displacement vectors. Default is 3.0.

temperature : float, optional

Temperature parameter for softmax. Default is 0.05.

epsilon : float, optional

Small constant for numerical stability. Default is 1e-7.

checkpoint_every : int, optional

Number of epochs between saving model checkpoints. Default is 10.

device : Optional[str], optional

Device for training. Default is ‘cuda’.

Raises:
ValueError

If warmup_fraction + decay_fraction is greater than 1.
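The documented ValueError suggests that warmup occupies the start of training and decay occupies the end, with no overlap between them. The plain-Python sketch below shows one plausible reading of the schedule parameters:

* a linear warmup over warmup_fraction of the epochs,
* a constant phase,
* decay_transitions equal-length steps over the final decay_fraction of the epochs, each multiplying the rate by decay_factor.

The exact shape Piscis uses is an assumption here, including whether warmup is linear and whether the rate changes per epoch or per batch. The helper learning_rate_schedule is hypothetical.

```python
from typing import List


def learning_rate_schedule(
    learning_rate: float = 0.1,
    epochs: int = 500,
    warmup_fraction: float = 0.04,
    decay_fraction: float = 0.4,
    decay_transitions: int = 10,
    decay_factor: float = 0.5,
) -> List[float]:
    """Per-epoch learning rates: linear warmup, constant phase, then step decay."""
    # Mirrors the documented check: the warmup and decay phases must not overlap.
    if warmup_fraction + decay_fraction > 1:
        raise ValueError("warmup_fraction + decay_fraction must not exceed 1")
    warmup_epochs = round(warmup_fraction * epochs)
    decay_epochs = round(decay_fraction * epochs)
    decay_start = epochs - decay_epochs
    rates = []
    for epoch in range(epochs):
        if epoch < warmup_epochs:
            # Ramp linearly up to the base learning rate.
            rates.append(learning_rate * (epoch + 1) / warmup_epochs)
        elif epoch < decay_start:
            rates.append(learning_rate)
        else:
            # Split the decay phase into equal-length steps, one per transition.
            step = (epoch - decay_start) * decay_transitions // decay_epochs + 1
            rates.append(learning_rate * decay_factor ** step)
    return rates


rates = learning_rate_schedule()  # defaults match train_model's documented defaults
print(len(rates))  # 500

try:
    learning_rate_schedule(warmup_fraction=0.6, decay_fraction=0.5)
except ValueError as err:
    print(err)  # warmup_fraction + decay_fraction must not exceed 1
```

With the defaults, this reading gives the following schedule:

* 20 warmup epochs,
* a constant rate of 0.1 until epoch 300,
* ten halvings of 20 epochs each, ending at 0.1 * 0.5**10 (about 9.8e-5).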