semipy.methods.abstractMethod
class semipy.methods.abstractMethod(args, model, dataloader, val_dataloaders, optimizer, scheduler, num_classes: int) -> None
Base class on which SSL methods are built. It defines every tool needed for training. Each method that inherits from this class must define the self.training() method.
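The inheritance contract can be illustrated with a minimal sketch. AbstractMethodSketch and MySSLMethod are hypothetical stand-ins, not semipy code: the stand-in only mirrors the documented constructor signature and the rule that training() is not implemented in the base class, which is assumed here to raise NotImplementedError. The subclass's loop body is a placeholder for a real SSL training step.

```python
class AbstractMethodSketch:
    """Hypothetical stand-in for semipy.methods.abstractMethod (not the library code)."""

    def __init__(self, args, model, dataloader, val_dataloaders,
                 optimizer, scheduler, num_classes: int) -> None:
        # Store every constructor argument so subclasses can use them in training().
        self.args = args
        self.model = model
        self.dataloader = dataloader
        self.val_dataloaders = val_dataloaders
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.num_classes = num_classes

    def training(self):
        # Assumption: the unimplemented base method signals this by raising.
        raise NotImplementedError("Subclasses must implement training()")


class MySSLMethod(AbstractMethodSketch):
    """Hypothetical SSL method that only overrides training()."""

    def training(self):
        steps = 0
        for batch in self.dataloader:
            # A real method would compute its supervised and unsupervised
            # losses on `batch` and step self.optimizer here.
            steps += 1
        return steps
```

Calling training() on the base class fails, while the subclass runs one step per batch: with a dataloader stand-in of [1, 2, 3], MySSLMethod(...).training() returns 3.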
Parameters
- args (dict) - Dictionary of parameters. To get a complete dictionary, use semipy.tools.get_config.
- model - Model to train. It can be a model from torchvision.models or a custom PyTorch model.
- dataloader (torch.utils.data.DataLoader) - Dataloader for training. It should use the provided semipy.sampler.JointSampler (or DistributedJointSampler) sampler.
- val_dataloaders (torch.utils.data.DataLoader) - Dataloader for validation
- optimizer (torch.optim.Optimizer) - Optimization algorithm
- scheduler (torch.optim.lr_scheduler.LRScheduler, optional) - Learning rate scheduler
- num_classes (int) - Number of classes in the dataset
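Because scheduler is optional, any code that steps it must tolerate None. The sketch below assumes a hypothetical helper, optimization_step, and a CountingStepper class standing in for both torch.optim.Optimizer and an LRScheduler; neither is part of semipy. It follows the PyTorch convention of calling optimizer.step() before scheduler.step().

```python
from typing import Optional, Protocol


class Steppable(Protocol):
    """Anything with a step() method (optimizer or LR scheduler)."""

    def step(self) -> None: ...


class CountingStepper:
    """Hypothetical stand-in for an optimizer or scheduler that only counts calls."""

    def __init__(self) -> None:
        self.calls = 0

    def step(self) -> None:
        self.calls += 1


def optimization_step(optimizer: Steppable, scheduler: Optional[Steppable]) -> None:
    """Hypothetical helper: update parameters, then the learning rate if a scheduler exists."""
    optimizer.step()
    if scheduler is not None:  # the scheduler argument is optional
        scheduler.step()
```

For example, three calls with a scheduler followed by one call with scheduler=None leave the optimizer stepped four times and the scheduler three times.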
Methods
Performs an evaluation of the model. Used during either validation or testing.
Saves the model using torch.save.
Not implemented. This method (self.training()) must be overridden when creating a new SSL method on top of the base class.
Used by self.eval() to pass each validation or test sample through the model and compute metrics.
Parameters
- dataloader (torch.utils.data.DataLoader) - Validation or test dataloader
- metrics_dict (dict) - Dictionary of metrics
- dataloader_idx (int) - Index of the dataloader when multiple validation or test sets are used. Default: 0
- action (str) - Which action to perform. Should be validation or test. Default: validation
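The interplay of dataloader, metrics_dict, dataloader_idx and action can be sketched as follows. This is a hypothetical illustration, not semipy's implementation: evaluate_loader and evaluate are made-up names, the model is a plain callable returning a class index, each batch is assumed to be an (inputs, targets) pair, and the metric key format f"{action}/accuracy_{dataloader_idx}" is an assumption. A real PyTorch version would typically also call model.eval() and run inside torch.no_grad().

```python
from typing import Callable, Dict, Iterable, Sequence, Tuple

Batch = Tuple[Sequence[float], Sequence[int]]


def evaluate_loader(model: Callable[[float], int],
                    dataloader: Iterable[Batch],
                    metrics_dict: Dict[str, float],
                    dataloader_idx: int = 0,
                    action: str = "validation") -> Dict[str, float]:
    """Hypothetical per-dataloader step: run every sample through the model, record accuracy."""
    if action not in ("validation", "test"):
        raise ValueError(f"action must be 'validation' or 'test', got {action!r}")
    correct = total = 0
    for inputs, targets in dataloader:
        for x, y in zip(inputs, targets):
            correct += int(model(x) == y)
            total += 1
    # Assumed key format: one entry per action and dataloader index.
    metrics_dict[f"{action}/accuracy_{dataloader_idx}"] = correct / total if total else 0.0
    return metrics_dict


def evaluate(model: Callable[[float], int],
             dataloaders: Sequence[Iterable[Batch]],
             action: str = "validation") -> Dict[str, float]:
    """Hypothetical eval(): call the per-dataloader step once per validation/test set."""
    metrics: Dict[str, float] = {}
    for idx, loader in enumerate(dataloaders):
        evaluate_loader(model, loader, metrics, dataloader_idx=idx, action=action)
    return metrics
```

With model = lambda x: int(x > 0.5), a first loader [([0.1, 0.9], [0, 1]), ([0.7], [0])] and a second loader [([0.2, 0.8], [0, 1])], evaluate returns an accuracy of 2/3 under "validation/accuracy_0" and 1.0 under "validation/accuracy_1". Passing dataloader_idx keeps metrics from different validation sets from overwriting one another.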