piscis.training
===============

.. py:module:: piscis.training


Functions
---------

.. autoapisummary::

   piscis.training.loss_fn
   piscis.training.train_epoch
   piscis.training.val_epoch
   piscis.training.train_model


Module Contents
---------------

.. py:function:: loss_fn(labels_pred: torch.Tensor, deltas_pred: torch.Tensor, labels: torch.Tensor, deltas: torch.Tensor, p: torch.Tensor, l2_loss_weight: float, max_distance: float, kernel_size: Sequence[int], temperature: float, epsilon: float) -> Tuple[torch.Tensor, Dict[str, float]]

   
   Compute the loss and metrics for a single batch.


   :Parameters:

       **labels_pred** : torch.Tensor
           Predicted labels.

       **deltas_pred** : torch.Tensor
           Predicted displacement vectors.

       **labels** : torch.Tensor
           Ground truth labels.

       **deltas** : torch.Tensor
           Ground truth displacement vectors.

       **p** : torch.Tensor
           Number of ground truth spots in each image.

       **l2_loss_weight** : float
           Weight for the masked L2 loss term in the overall loss function.

       **max_distance** : float
           Maximum distance for matching predicted and ground truth displacement vectors.

       **kernel_size** : Sequence[int]
           Kernel size of the sum or max pooling operations, e.g. (3, 3).

       **temperature** : float
           Temperature parameter for softmax.

       **epsilon** : float
           Small constant for numerical stability.



   :Returns:

       **loss** : torch.Tensor
           Overall loss value.

       **metrics** : Dict[str, float]
           Dictionary containing the values of individual loss terms and the overall loss.
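
   As a rough illustration of how a masked L2 term might be weighted into the overall loss, the sketch below
   computes a mean squared displacement error over pixels whose ground truth label is 1 and adds it, scaled by
   ``l2_loss_weight``, to a placeholder ``base_loss``. The masking rule, the ``base_loss`` term, and the metric
   names are assumptions for illustration, not the actual terms computed by ``loss_fn``; plain Python lists stand
   in for tensors.

   .. code-block:: python

      from typing import Dict, Sequence, Tuple

      Vector = Tuple[float, float]  # (dy, dx) displacement for one pixel


      def masked_l2_loss(deltas_pred: Sequence[Vector], deltas: Sequence[Vector],
                         labels: Sequence[int], epsilon: float = 1e-7) -> float:
          """Mean squared displacement error over labeled pixels (assumed masking rule)."""
          total, count = 0.0, 0
          for (py, px), (ty, tx), label in zip(deltas_pred, deltas, labels):
              if label:  # background pixels are masked out
                  total += (py - ty) ** 2 + (px - tx) ** 2
                  count += 1
          return total / (count + epsilon)  # epsilon avoids division by zero


      def combine_losses(base_loss: float, l2_loss: float,
                         l2_loss_weight: float) -> Tuple[float, Dict[str, float]]:
          """Weighted sum of a hypothetical base term and the masked L2 term."""
          loss = base_loss + l2_loss_weight * l2_loss
          return loss, {"base_loss": base_loss, "l2_loss": l2_loss, "loss": loss}

   For example, ``masked_l2_loss([(1.0, 0.0), (5.0, 5.0), (0.0, 0.0)], [(0.0, 0.0)] * 3, [1, 0, 1])`` ignores
   the unlabeled middle pixel and returns approximately 0.5.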











   ..
       !! processed by numpydoc !!
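
   The ``temperature`` parameter is described for ``train_epoch``, ``val_epoch``, and ``train_model`` as a
   softmax temperature. The generic sketch below shows what that scaling does: dividing logits by a small
   temperature such as 0.05 pushes the softmax toward a hard maximum, while ``epsilon`` keeps a logarithm finite
   when a probability underflows to zero. Where exactly ``loss_fn`` applies these is not documented here, so treat
   this as a standalone illustration with hypothetical helper names.

   .. code-block:: python

      import math
      from typing import List, Sequence


      def softmax_with_temperature(logits: Sequence[float], temperature: float) -> List[float]:
          """Softmax of logits / temperature; smaller temperatures give sharper outputs."""
          scaled = [x / temperature for x in logits]
          shift = max(scaled)  # subtract the maximum so math.exp cannot overflow
          exps = [math.exp(s - shift) for s in scaled]
          total = sum(exps)
          return [e / total for e in exps]


      def stable_log(x: float, epsilon: float = 1e-7) -> float:
          """Logarithm that stays finite for x == 0."""
          return math.log(x + epsilon)

   With logits ``[1.0, 2.0]``, a temperature of 1.0 gives the larger logit a probability of about 0.73, whereas a
   temperature of 0.05 raises it above 0.999.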

.. py:function:: train_epoch(model: piscis.models.spots.SpotsModel, dataloader: tqdm.tqdm, optimizer: torch.optim.Optimizer, l2_loss_weight: float, dilation_iterations: int, max_distance: float, temperature: float, epsilon: float, device: Optional[str]) -> Dict[str, float]

   
   Train the model for a single epoch.


   :Parameters:

       **model** : SpotsModel
           Model to be trained.

       **dataloader** : tqdm
           DataLoader for the training data, wrapped in a ``tqdm`` progress bar.

       **optimizer** : torch.optim.Optimizer
           Optimizer for updating model parameters.

       **l2_loss_weight** : float
           Weight for the masked L2 loss term in the overall loss function.

       **dilation_iterations** : int
           Number of iterations to dilate ground truth labels.

       **max_distance** : float
           Maximum distance for matching predicted and ground truth displacement vectors.

       **temperature** : float
           Temperature parameter for softmax.

       **epsilon** : float
           Small constant for numerical stability.

       **device** : Optional[str]
           Device for training.



   :Returns:

       **train_metrics** : Dict[str, float]
           Dictionary containing average training metrics for the epoch.
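
   One plausible way to build the returned dictionary is to average each per-batch metric over the epoch. The
   sketch below assumes an unweighted mean over batches (the real implementation may weight by batch size) and
   uses a hypothetical helper and metric names; the same reduction would apply to the dictionary returned by
   ``val_epoch``.

   .. code-block:: python

      from collections import defaultdict
      from typing import Dict, Iterable


      def average_metrics(batch_metrics: Iterable[Dict[str, float]]) -> Dict[str, float]:
          """Unweighted mean of each metric across all batches of one epoch."""
          totals: Dict[str, float] = defaultdict(float)
          counts: Dict[str, int] = defaultdict(int)
          for metrics in batch_metrics:
              for name, value in metrics.items():
                  totals[name] += value
                  counts[name] += 1
          return {name: totals[name] / counts[name] for name in totals}

   For instance, averaging ``{"loss": 1.0, "l2_loss": 0.5}`` and ``{"loss": 3.0, "l2_loss": 1.5}`` yields
   ``{"loss": 2.0, "l2_loss": 1.0}``.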











   ..
       !! processed by numpydoc !!
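
   ``dilation_iterations`` controls how far each ground truth spot is grown before the loss is computed. The
   sketch below repeats a binary dilation with a 3x3 square neighborhood on a plain nested list; the
   neighborhood shape and the helper name are assumptions. ``scipy.ndimage.binary_dilation`` with its
   ``iterations`` argument is a library equivalent, although its default structuring element is a cross rather
   than a square.

   .. code-block:: python

      from typing import List

      Grid = List[List[int]]


      def dilate(labels: Grid, iterations: int) -> Grid:
          """Binary dilation with a 3x3 square neighborhood, repeated `iterations` times."""
          height, width = len(labels), len(labels[0])
          current = [row[:] for row in labels]  # copy so the input is not modified
          for _ in range(iterations):
              grown = [row[:] for row in current]
              for y in range(height):
                  for x in range(width):
                      if current[y][x]:
                          # switch on every in-bounds pixel of the 3x3 neighborhood
                          for dy in (-1, 0, 1):
                              for dx in (-1, 0, 1):
                                  ny, nx = y + dy, x + dx
                                  if 0 <= ny < height and 0 <= nx < width:
                                      grown[ny][nx] = 1
              current = grown
          return current

   A single labeled pixel in the center of a 5x5 grid grows into a 3x3 block after one iteration and fills the
   whole grid after two.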

.. py:function:: val_epoch(model: piscis.models.spots.SpotsModel, dataloader: torch.utils.data.DataLoader, l2_loss_weight: float, dilation_iterations: int, max_distance: float, temperature: float, epsilon: float, device: Optional[str]) -> Dict[str, float]

   
   Validate the model for a single epoch.


   :Parameters:

       **model** : SpotsModel
           Model to be validated.

       **dataloader** : torch.utils.data.DataLoader
           DataLoader for the validation data.

       **l2_loss_weight** : float
           Weight for the masked L2 loss term in the overall loss function.

       **dilation_iterations** : int
           Number of iterations to dilate ground truth labels.

       **max_distance** : float
           Maximum distance for matching predicted and ground truth displacement vectors.

       **temperature** : float
           Temperature parameter for softmax.

       **epsilon** : float
           Small constant for numerical stability.

       **device** : Optional[str]
           Device for validation.



   :Returns:

       **val_metrics** : Dict[str, float]
           Dictionary containing average validation metrics for the epoch.











   ..
       !! processed by numpydoc !!

.. py:function:: train_model(model_name: str, dataset_path: Union[str, List[str], Dict[str, float]], initial_model_name: Optional[str] = None, adjustment: Optional[str] = 'standardize', input_size: Tuple[int, int] = (256, 256), random_seed: int = 0, batch_size: int = 4, num_workers: int = 0, learning_rate: float = 0.1, weight_decay: float = 1e-05, epochs: int = 500, warmup_fraction: float = 0.04, decay_fraction: float = 0.4, decay_transitions: int = 10, decay_factor: float = 0.5, l2_loss_weight: float = 0.1, dilation_iterations: int = 1, max_distance: float = 3.0, temperature: float = 0.05, epsilon: float = 1e-07, checkpoint_every: int = 10, device: Optional[str] = 'cuda') -> None

   
   Train a SpotsModel.


   :Parameters:

       **model_name** : str
           Model name.

       **dataset_path** : Union[str, List[str], Dict[str, float]]
           Path to a dataset, path to a directory containing multiple datasets, a list of multiple dataset paths, or a
           dictionary of multiple dataset paths and their corresponding sampling weights. If a directory of datasets or a
           list is provided, all datasets in the directory or list will be loaded and concatenated with equal weights. If
           a dictionary is provided, the datasets will be loaded and concatenated with the specified weights.

       **initial_model_name** : Optional[str], optional
           Name of an existing model whose weights are used to initialize training. Default is None.

       **adjustment** : Optional[str], optional
           Adjustment type applied to images. Supported types are 'normalize' and 'standardize'. Default is 'standardize'.

       **input_size** : Tuple[int, int], optional
           Input size used for training. Default is (256, 256).

       **random_seed** : int, optional
           Random seed used for initialization and training. Default is 0.

       **batch_size** : int, optional
           Batch size for training. Default is 4.

       **num_workers** : int, optional
           Number of workers for data loading. Default is 0.

       **learning_rate** : float, optional
           Learning rate for the optimizer. Default is 0.1.

       **weight_decay** : float, optional
           Strength of the weight decay regularization. Default is 1e-5.

       **epochs** : int, optional
           Number of epochs to train the model for. Default is 500.

       **warmup_fraction** : float, optional
           Fraction of epochs for learning rate warmup. Default is 0.04.

       **decay_fraction** : float, optional
           Fraction of epochs for learning rate decay. Default is 0.4.

       **decay_transitions** : int, optional
           Number of times to decay the learning rate. Default is 10.

       **decay_factor** : float, optional
           Multiplicative factor of each learning rate decay transition. Default is 0.5.

       **l2_loss_weight** : float, optional
           Weight for the masked L2 loss term in the overall loss function. Default is 0.1.

       **dilation_iterations** : int, optional
           Number of iterations to dilate ground truth labels to minimize class imbalance and misclassifications due to
           minor offsets. Default is 1.

       **max_distance** : float, optional
           Maximum distance for matching predicted and ground truth displacement vectors. Default is 3.0.

       **temperature** : float, optional
           Temperature parameter for softmax. Default is 0.05.

       **epsilon** : float, optional
           Small constant for numerical stability. Default is 1e-7.

       **checkpoint_every** : int, optional
           Number of epochs between saving model checkpoints. Default is 10.

       **device** : Optional[str], optional
           Device for training. Default is 'cuda'.
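
   The sketch below illustrates one reading of the ``dataset_path`` weighting rules: a list receives equal
   weights, and a dictionary's weights are normalized into per-dataset sampling probabilities. The helper names
   are hypothetical, a single path or a directory of datasets is not handled, and the actual loader may weight
   individual samples rather than whole datasets.

   .. code-block:: python

      import random
      from typing import Dict, List, Union


      def sampling_probabilities(dataset_path: Union[List[str], Dict[str, float]]) -> Dict[str, float]:
          """Probability of drawing a training sample from each dataset."""
          if isinstance(dataset_path, dict):
              weights = dict(dataset_path)
          else:
              weights = {path: 1.0 for path in dataset_path}  # equal weights for a list
          total = sum(weights.values())
          if total <= 0:
              raise ValueError("Sampling weights must sum to a positive value.")
          return {path: weight / total for path, weight in weights.items()}


      def draw_datasets(probabilities: Dict[str, float], k: int, seed: int = 0) -> List[str]:
          """Choose, reproducibly, which dataset each of k samples comes from."""
          rng = random.Random(seed)
          paths = list(probabilities)
          return rng.choices(paths, weights=[probabilities[p] for p in paths], k=k)

   Here ``{"dataset_a": 3.0, "dataset_b": 1.0}`` becomes ``{"dataset_a": 0.75, "dataset_b": 0.25}``, so about
   three in four samples would be drawn from ``dataset_a``.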







   :Raises:

       ValueError
           If warmup_fraction + decay_fraction is greater than 1.
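
   The learning rate parameters and the ``ValueError`` above suggest a three-phase schedule. The sketch below
   assumes linear warmup over the first ``warmup_fraction`` of epochs, a constant rate in the middle, and
   ``decay_transitions`` equal-length steps over the final ``decay_fraction`` of epochs, each multiplying the rate
   by ``decay_factor``. The exact shape (linear warmup, per-epoch granularity, first decay at the start of the
   decay phase) and the helper name are assumptions; a list of factors like this could be wrapped in
   ``torch.optim.lr_scheduler.LambdaLR``.

   .. code-block:: python

      from typing import List


      def lr_schedule(learning_rate: float = 0.1, epochs: int = 500, warmup_fraction: float = 0.04,
                      decay_fraction: float = 0.4, decay_transitions: int = 10,
                      decay_factor: float = 0.5) -> List[float]:
          """Per-epoch learning rates: linear warmup, constant plateau, then step decay."""
          if warmup_fraction + decay_fraction > 1:
              raise ValueError("warmup_fraction + decay_fraction must not exceed 1.")
          warmup_epochs = int(warmup_fraction * epochs)
          decay_epochs = int(decay_fraction * epochs)
          decay_start = epochs - decay_epochs
          rates = []
          for epoch in range(epochs):
              if epoch < warmup_epochs:
                  # ramp linearly up to learning_rate at the last warmup epoch
                  rate = learning_rate * (epoch + 1) / warmup_epochs
              elif epoch < decay_start:
                  rate = learning_rate
              else:
                  # index of the current decay step, from 1 to decay_transitions
                  step = (epoch - decay_start) * decay_transitions // decay_epochs + 1
                  rate = learning_rate * decay_factor ** step
              rates.append(rate)
          return rates

   With the defaults, the rate ramps up to 0.1 by epoch 19, stays there through epoch 299, and then halves every
   20 epochs, ending at ``0.1 * 0.5 ** 10``.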







   ..
       !! processed by numpydoc !!

