piscis.training
Functions
- loss_fn: Computes the loss and metrics for a given batch.
- train_epoch: Train the model for a single epoch.
- val_epoch: Validate the model for a single epoch.
- train_model: Train a SpotsModel.
Module Contents
- piscis.training.loss_fn(labels_pred: torch.Tensor, deltas_pred: torch.Tensor, labels: torch.Tensor, deltas: torch.Tensor, p: torch.Tensor, l2_loss_weight: float, max_distance: float, kernel_size: Sequence[int], temperature: float, epsilon: float) -> Tuple[torch.Tensor, Dict[str, float]]
Computes the loss and metrics for a given batch.
- Parameters:
- labels_pred: torch.Tensor
Predicted labels.
- deltas_pred: torch.Tensor
Predicted displacement vectors.
- labels: torch.Tensor
Ground truth labels.
- deltas: torch.Tensor
Ground truth displacement vectors.
- p: torch.Tensor
Number of ground truth spots in each image.
- l2_loss_weight: float
Weight for the masked L2 loss term in the overall loss function.
- max_distance: float
Maximum distance for matching predicted and ground truth displacement vectors.
- kernel_size: Sequence[int], optional
Kernel size for the sum or max pooling operations. Default is (3, 3).
- temperature: float
Temperature parameter for softmax.
- epsilon: float
Small constant for numerical stability.
- Returns:
- loss: torch.Tensor
Overall loss value.
- metrics: Dict[str, float]
Dictionary containing the values of individual loss terms and the overall loss.
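The role of `l2_loss_weight` can be illustrated with a minimal pure-Python sketch. It assumes, hypothetically, that the masked L2 term is the mean squared error between predicted and ground truth displacement vectors over pixels whose ground truth label is positive, and that this term is scaled by `l2_loss_weight` and added to the other loss term. The helpers `masked_l2_loss` and `combined_loss` and the placeholder value for the other term are illustrative only; they are not part of the piscis API, and the sketch ignores `p`, `max_distance`, `kernel_size`, and `temperature`.

```python
from typing import Sequence, Tuple

Vector = Tuple[float, float]


def masked_l2_loss(
    deltas_pred: Sequence[Vector],
    deltas: Sequence[Vector],
    mask: Sequence[int],
    epsilon: float = 1e-7,
) -> float:
    """Mean squared displacement error over pixels where mask == 1 (hypothetical helper)."""
    total = 0.0
    count = 0
    for (py, px), (ty, tx), m in zip(deltas_pred, deltas, mask):
        if m:
            total += (py - ty) ** 2 + (px - tx) ** 2
            count += 1
    # epsilon avoids division by zero when no pixel is labeled as a spot.
    return total / (count + epsilon)


def combined_loss(other_loss: float, l2_term: float, l2_loss_weight: float) -> float:
    """Assumed form of the overall loss: other term + l2_loss_weight * masked L2 term."""
    return other_loss + l2_loss_weight * l2_term


deltas_pred = [(0.5, 0.0), (1.0, 1.0), (3.0, -2.0)]
deltas = [(0.0, 0.0), (1.0, 0.0), (0.0, 0.0)]
mask = [1, 1, 0]  # the third pixel is background, so its large error is ignored

l2_term = masked_l2_loss(deltas_pred, deltas, mask)
print(round(l2_term, 4))  # 0.625
print(round(combined_loss(0.8, l2_term, l2_loss_weight=0.1), 4))  # 0.8625
```

Masking restricts the displacement penalty to pixels near ground truth spots, where a displacement vector is meaningful; background pixels contribute nothing to this term.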
- piscis.training.train_epoch(model: piscis.models.spots.SpotsModel, dataloader: tqdm.tqdm, optimizer: torch.optim.Optimizer, l2_loss_weight: float, dilation_iterations: int, max_distance: float, temperature: float, epsilon: float, device: str | None) -> Dict[str, float]
Train the model for a single epoch.
- Parameters:
- model: SpotsModel
Model to be trained.
- dataloader: tqdm
DataLoader for the training data, wrapped in a tqdm progress bar.
- optimizer: torch.optim.Optimizer
Optimizer for updating model parameters.
- l2_loss_weight: float
Weight for the masked L2 loss term in the overall loss function.
- dilation_iterations: int
Number of iterations to dilate ground truth labels.
- max_distance: float
Maximum distance for matching predicted and ground truth displacement vectors.
- temperature: float
Temperature parameter for softmax.
- epsilon: float
Small constant for numerical stability.
- device: Optional[str]
Device for training.
- Returns:
- train_metrics: Dict[str, float]
Dictionary containing average training metrics for the epoch.
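`train_epoch` returns metrics averaged over the epoch. One common way to produce such averages, sketched below under the assumption that each batch yields a dictionary like the `metrics` returned by `loss_fn`, is to accumulate each key and divide by the number of batches. The helper `average_metrics` and the metric key names are hypothetical; piscis may weight batches differently (for example, by batch size).

```python
from collections import defaultdict
from typing import Dict, Iterable


def average_metrics(batch_metrics: Iterable[Dict[str, float]]) -> Dict[str, float]:
    """Unweighted mean of each metric over batches (hypothetical helper)."""
    totals: Dict[str, float] = defaultdict(float)
    num_batches = 0
    for metrics in batch_metrics:
        for key, value in metrics.items():
            totals[key] += value
        num_batches += 1
    if num_batches == 0:
        return {}
    return {key: total / num_batches for key, total in totals.items()}


# Hypothetical per-batch outputs with illustrative key names.
history = [
    {"loss": 1.0, "l2_loss": 0.5},
    {"loss": 0.5, "l2_loss": 0.25},
]
print(average_metrics(history))  # {'loss': 0.75, 'l2_loss': 0.375}
```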
- piscis.training.val_epoch(model: piscis.models.spots.SpotsModel, dataloader: torch.utils.data.DataLoader, l2_loss_weight: float, dilation_iterations: int, max_distance: float, temperature: float, epsilon: float, device: str | None) -> Dict[str, float]
Validate the model for a single epoch.
- Parameters:
- model: SpotsModel
Model to be validated.
- dataloader: torch.utils.data.DataLoader
DataLoader for the validation data.
- l2_loss_weight: float
Weight for the masked L2 loss term in the overall loss function.
- dilation_iterations: int
Number of iterations to dilate ground truth labels.
- max_distance: float
Maximum distance for matching predicted and ground truth displacement vectors.
- temperature: float
Temperature parameter for softmax.
- epsilon: float
Small constant for numerical stability.
- device: Optional[str]
Device for validation.
- Returns:
- val_metrics: Dict[str, float]
Dictionary containing average validation metrics for the epoch.
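Both `train_epoch` and `val_epoch` take `dilation_iterations`, the number of times the ground truth labels are dilated. A minimal pure-Python sketch of iterated binary dilation follows; the 3x3 square neighborhood and the helper name `dilate` are assumptions for illustration, not taken from piscis. Growing a single-pixel spot into a 3x3 patch increases the number of positive pixels, which matches the class-imbalance motivation given for `dilation_iterations` in `train_model`.

```python
from typing import List

Grid = List[List[int]]


def dilate(labels: Grid, iterations: int) -> Grid:
    """Binary dilation with an assumed 3x3 square neighborhood, repeated `iterations` times."""
    height, width = len(labels), len(labels[0])
    current = [row[:] for row in labels]  # copy so the input is not mutated
    for _ in range(iterations):
        dilated = [[0] * width for _ in range(height)]
        for y in range(height):
            for x in range(width):
                if current[y][x]:
                    # Switch on every in-bounds pixel in the 3x3 window around (y, x).
                    for dy in (-1, 0, 1):
                        for dx in (-1, 0, 1):
                            ny, nx = y + dy, x + dx
                            if 0 <= ny < height and 0 <= nx < width:
                                dilated[ny][nx] = 1
        current = dilated
    return current


spot = [[0] * 5 for _ in range(5)]
spot[2][2] = 1  # a single ground truth spot in the center
for row in dilate(spot, iterations=1):
    print(row)
# [0, 0, 0, 0, 0]
# [0, 1, 1, 1, 0]
# [0, 1, 1, 1, 0]
# [0, 1, 1, 1, 0]
# [0, 0, 0, 0, 0]
```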
- piscis.training.train_model(model_name: str, dataset_path: str | List[str] | Dict[str, float], initial_model_name: str | None = None, adjustment: str | None = 'standardize', input_size: Tuple[int, int] = (256, 256), random_seed: int = 0, batch_size: int = 4, num_workers: int = 0, learning_rate: float = 0.1, weight_decay: float = 1e-05, epochs: int = 500, warmup_fraction: float = 0.04, decay_fraction: float = 0.4, decay_transitions: int = 10, decay_factor: float = 0.5, l2_loss_weight: float = 0.1, dilation_iterations: int = 1, max_distance: float = 3.0, temperature: float = 0.05, epsilon: float = 1e-07, checkpoint_every: int = 10, device: str | None = 'cuda') -> None
Train a SpotsModel.
- Parameters:
- model_name: str
Model name.
- dataset_path: Union[str, List[str], Dict[str, float]]
Path to a dataset, path to a directory containing multiple datasets, a list of multiple dataset paths, or a dictionary of multiple dataset paths and their corresponding sampling weights. If a directory of datasets or a list is provided, all datasets in the directory or list will be loaded and concatenated with equal weights. If a dictionary is provided, the datasets will be loaded and concatenated with the specified weights.
- initial_model_name: Optional[str], optional
Name of an existing model to initialize the weights. Default is None.
- adjustment: Optional[str], optional
Adjustment type applied to images. Supported types are ‘normalize’ and ‘standardize’. Default is ‘standardize’.
- input_size: Tuple[int, int], optional
Input size used for training. Default is (256, 256).
- random_seed: int, optional
Random seed used for initialization and training. Default is 0.
- batch_size: int, optional
Batch size for training. Default is 4.
- num_workers: int, optional
Number of workers for data loading. Default is 0.
- learning_rate: float, optional
Learning rate for the optimizer. Default is 0.1.
- weight_decay: float, optional
Strength of the weight decay regularization. Default is 1e-5.
- epochs: int, optional
Number of epochs to train the model for. Default is 500.
- warmup_fraction: float, optional
Fraction of epochs for learning rate warmup. Default is 0.04.
- decay_fraction: float, optional
Fraction of epochs for learning rate decay. Default is 0.4.
- decay_transitions: int, optional
Number of times to decay the learning rate. Default is 10.
- decay_factor: float, optional
Multiplicative factor of each learning rate decay transition. Default is 0.5.
- l2_loss_weight: float, optional
Weight for the masked L2 loss term in the overall loss function. Default is 0.1.
- dilation_iterations: int, optional
Number of iterations to dilate ground truth labels to minimize class imbalance and misclassifications due to minor offsets. Default is 1.
- max_distance: float, optional
Maximum distance for matching predicted and ground truth displacement vectors. Default is 3.0.
- temperature: float, optional
Temperature parameter for softmax. Default is 0.05.
- epsilon: float, optional
Small constant for numerical stability. Default is 1e-7.
- checkpoint_every: int, optional
Number of epochs between saving model checkpoints. Default is 10.
- device: Optional[str], optional
Device for training. Default is ‘cuda’.
- Raises:
- ValueError
If warmup_fraction + decay_fraction is greater than 1.
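The learning rate parameters of `train_model` can be read as a piecewise schedule. The sketch below is one plausible interpretation, not the piscis implementation: it assumes a linear warmup over the first `warmup_fraction * epochs` epochs, a constant `learning_rate` afterwards, and, during the final `decay_fraction * epochs` epochs, `decay_transitions` equally spaced steps that each multiply the rate by `decay_factor`. The function `learning_rate_at` is hypothetical; it also reproduces the documented `ValueError` check.

```python
import math


def learning_rate_at(
    epoch: int,
    epochs: int = 500,
    learning_rate: float = 0.1,
    warmup_fraction: float = 0.04,
    decay_fraction: float = 0.4,
    decay_transitions: int = 10,
    decay_factor: float = 0.5,
) -> float:
    """Learning rate for a 0-indexed epoch under an assumed warmup/plateau/step-decay schedule."""
    if warmup_fraction + decay_fraction > 1:
        raise ValueError("warmup_fraction + decay_fraction must not exceed 1.")
    warmup_epochs = int(warmup_fraction * epochs)
    decay_epochs = int(decay_fraction * epochs)
    decay_start = epochs - decay_epochs

    if epoch < warmup_epochs:
        # Linear warmup from learning_rate / warmup_epochs up to learning_rate.
        return learning_rate * (epoch + 1) / warmup_epochs
    if epoch < decay_start:
        # Plateau between warmup and decay.
        return learning_rate
    # Split the decay phase into decay_transitions equal segments;
    # segment k (0-indexed) applies decay_factor ** (k + 1).
    epochs_per_transition = decay_epochs / decay_transitions
    step = math.floor((epoch - decay_start) / epochs_per_transition) + 1
    return learning_rate * decay_factor ** min(step, decay_transitions)


print([round(learning_rate_at(e), 6) for e in (0, 19, 300, 320, 499)])
# [0.005, 0.1, 0.05, 0.025, 9.8e-05]
```

With the defaults, this reading gives 20 warmup epochs, a plateau until epoch 300, and ten halvings over the last 200 epochs, so the final rate is `0.1 * 0.5 ** 10`.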