
semipy.methods.abstractMethod

    class semipy.method.abstractMethod(args, model, dataloader, val_dataloaders,
                                       optimizer, scheduler, num_classes: int) -> None

Base class providing the common structure on which SSL methods are built. It defines the tools needed for training (evaluation, metric computation and model saving); each method that inherits from this class must define the self.training() method.
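The override pattern can be sketched without semipy installed. The class AbstractMethodSketch below is a hypothetical stand-in that only mirrors the documented constructor and the unimplemented training() method; it is not the library class, and MyMethod is an illustrative name.

```python
class AbstractMethodSketch:
    """Hypothetical stand-in for semipy's abstractMethod (not the real class)."""

    def __init__(self, args, model, dataloader, val_dataloaders,
                 optimizer, scheduler, num_classes: int) -> None:
        # Store everything a concrete method needs during training.
        self.args = args
        self.model = model
        self.dataloader = dataloader
        self.val_dataloaders = val_dataloaders
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.num_classes = num_classes

    def training(self):
        # As documented, the base class leaves training() unimplemented.
        raise NotImplementedError


class MyMethod(AbstractMethodSketch):
    """Hypothetical SSL method: only training() has to be provided."""

    def training(self):
        # A real method would iterate over self.dataloader, compute its
        # supervised and unsupervised losses and step self.optimizer.
        return f"training {self.num_classes}-class model"


method = MyMethod({}, None, None, None, None, None, num_classes=10)
print(method.training())  # prints "training 10-class model"
```

Everything except training() is inherited, which is the point of the base class: a new SSL method only supplies its own training loop.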

Parameters

  • args (dict) - Dictionary of parameters. To obtain a complete dictionary, use semipy.tools.get_config
  • model - Model to train. It can be a model from torchvision.models or a custom PyTorch model
  • dataloader (torch.utils.data.DataLoader) - Dataloader for training. It should be built with the sampler provided by semipy, semipy.sampler.JointSampler (or DistributedJointSampler)
  • val_dataloaders (torch.utils.data.DataLoader) - Dataloader for validation
  • optimizer (torch.optim.Optimizer) - Optimization algorithm
  • scheduler (torch.optim.lr_scheduler.LRScheduler, optional) - Learning rate scheduler
  • num_classes (int) - Number of classes in the dataset
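The objects expected by the constructor can be assembled with plain PyTorch, as in the sketch below. The toy model, random data, SGD settings and cosine scheduler are illustrative choices only. The training loader here is a plain shuffled DataLoader used as a placeholder, because semipy expects it to be built with semipy.sampler.JointSampler, whose arguments are not documented on this page; likewise args is left empty instead of calling semipy.tools.get_config.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

num_classes = 10

# Toy classifier; a torchvision.models network or any custom nn.Module fits here.
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, num_classes))

optimizer = torch.optim.SGD(model.parameters(), lr=0.03, momentum=0.9)
# Optional learning-rate scheduler.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)

# Random tensors standing in for real labelled data.
train_set = TensorDataset(torch.randn(128, 3, 32, 32),
                          torch.randint(0, num_classes, (128,)))
val_set = TensorDataset(torch.randn(64, 3, 32, 32),
                        torch.randint(0, num_classes, (64,)))

# Placeholder training loader: semipy expects a JointSampler-based loader here.
train_dataloader = DataLoader(train_set, batch_size=32, shuffle=True)
val_dataloader = DataLoader(val_set, batch_size=32, shuffle=False)

# Keyword arguments named after the documented constructor signature.
method_kwargs = dict(
    args={},  # placeholder; semipy.tools.get_config provides the full dictionary
    model=model,
    dataloader=train_dataloader,
    val_dataloaders=val_dataloader,
    optimizer=optimizer,
    scheduler=scheduler,
    num_classes=num_classes,
)
```

A concrete subclass of abstractMethod would then receive these as keyword arguments. Whether val_dataloaders also accepts a list of loaders is not stated on this page.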

    Methods

        eval()
    

    Evaluates the model. Used during both validation and testing.


        save(name)
    

    Saves the model using torch.save.
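A minimal sketch of saving and restoring weights with torch.save and torch.load follows. What save(name) actually writes (the whole model or its state_dict) and which file name or extension it uses are not documented here, so save_model, the '.pt' suffix and the directory argument are assumptions of this sketch.

```python
import os
import tempfile

import torch
from torch import nn


def save_model(model: nn.Module, name: str, directory: str) -> str:
    """Hypothetical equivalent of save(name): store the weights via torch.save.

    Saving the state_dict and the '.pt' suffix are assumptions of this sketch.
    """
    path = os.path.join(directory, f"{name}.pt")
    torch.save(model.state_dict(), path)
    return path


model = nn.Linear(4, 2)
path = save_model(model, "checkpoint", tempfile.gettempdir())

# Restore the weights into a freshly built model with the same architecture.
restored = nn.Linear(4, 2)
restored.load_state_dict(torch.load(path))
```

Saving the state_dict rather than the whole module object is the pattern the PyTorch documentation recommends, since the checkpoint then does not pickle the model class itself.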


        training()
    

    Not implemented in the base class. It must be overridden by every SSL method built on top of this class.


        val_test_pass(dataloader, metrics_dict, dataloader_idx=0, action='validation')
    

    Used by self.eval() to pass each validation or test sample through the model and compute metrics.

    Parameters

    • dataloader (torch.utils.data.DataLoader) - Validation or test dataloader
    • metrics_dict (dict) - Dictionary of metrics
    • dataloader_idx (int) - Index of dataloader in case of using multiple validation or test sets. Default: 0
    • action (str) - Which action to perform. Should be 'validation' or 'test'. Default: 'validation'
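How eval() and val_test_pass() may cooperate can be sketched as below: one pass per dataloader, with dataloader_idx and action used to keep results from several sets apart. The functions eval_sketch and val_test_pass_sketch are hypothetical; the layout of metrics_dict (metric name mapped to a callable taking logits and targets) and the result-key format are assumptions, since the page only says "Dictionary of metrics".

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def accuracy(logits: torch.Tensor, targets: torch.Tensor) -> float:
    # Fraction of samples whose highest-scoring class matches the target.
    return (logits.argmax(dim=1) == targets).float().mean().item()


@torch.no_grad()
def val_test_pass_sketch(model, dataloader, metrics_dict,
                         dataloader_idx=0, action="validation"):
    """Hypothetical stand-in for val_test_pass: run every batch through the
    model, then compute each metric on the concatenated outputs."""
    if action not in ("validation", "test"):
        raise ValueError("action should be 'validation' or 'test'")
    was_training = model.training
    model.eval()
    all_logits, all_targets = [], []
    for inputs, targets in dataloader:
        all_logits.append(model(inputs))
        all_targets.append(targets)
    model.train(was_training)  # restore the previous mode
    logits, targets = torch.cat(all_logits), torch.cat(all_targets)
    # Prefix keys with the action and loader index so that results from
    # several validation/test sets do not overwrite each other.
    return {f"{action}/{dataloader_idx}/{name}": fn(logits, targets)
            for name, fn in metrics_dict.items()}


def eval_sketch(model, dataloaders, metrics_dict, action="validation"):
    """Hypothetical stand-in for eval: one pass per dataloader."""
    results = {}
    for idx, loader in enumerate(dataloaders):
        results.update(val_test_pass_sketch(model, loader, metrics_dict,
                                            dataloader_idx=idx, action=action))
    return results


# One-hot inputs through an identity "model" make predictions equal targets.
targets = torch.tensor([0, 1, 2, 1])
inputs = nn.functional.one_hot(targets, num_classes=3).float()
loader = DataLoader(TensorDataset(inputs, targets), batch_size=2)
results = eval_sketch(nn.Identity(), [loader, loader], {"accuracy": accuracy})
print(results)  # {'validation/0/accuracy': 1.0, 'validation/1/accuracy': 1.0}
```

Computing metrics on the concatenated outputs, rather than averaging per-batch values, keeps metrics such as accuracy exact when the last batch is smaller than the others.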