semipy.methods.utils.pimodel_loss
Warning
This section is under construction.
semipy.methods.utils.pimodel_loss(model, weak_x, strong_x, y, lbda, current_iter, pimodel_n_batches, warmup_length, debiased=False)
Parameters
- model - The model to train.
- weak_x (torch.Tensor) - One batch of weakly augmented data (both labelled and unlabelled).
- strong_x (torch.Tensor) - The same batch as weak_x, but with strong augmentation.
- y (torch.Tensor) - Labels of the current batch; unlabelled samples are labelled -1.
- lbda (float) - Balancing weight for the unlabelled loss.
- current_iter (int) - Current training iteration, computed as number of batches per epoch * current epoch + current batch.
- pimodel_n_batches (int) - Number of batches in a PiModel-defined epoch, i.e. the number of batches after which all items, labelled and unlabelled, have been seen.
- warmup_length (int) - Number of epochs in the PiModel-defined warmup.
- debiased (bool) - Whether to activate safe SSL (Schmutz et al.) via debiasing. Default: False
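The parameters above do not state how current_iter, pimodel_n_batches and warmup_length interact. One plausible reading, sketched here as an assumption and not as semipy's actual code, is that they drive the sigmoid-shaped ramp-up of the unlabelled weight used in the original Pi-Model (Laine & Aila, 2017), w(t) = exp(-5(1 - t)^2), where t is the fraction of the warmup completed. The helper names current_iteration and rampup_weight are hypothetical.

```python
import math


def current_iteration(n_batches_per_epoch: int, epoch: int, batch: int) -> int:
    """Hypothetical helper: global iteration index, as described for current_iter."""
    return n_batches_per_epoch * epoch + batch


def rampup_weight(current_iter: int, pimodel_n_batches: int, warmup_length: int) -> float:
    """Hypothetical sigmoid ramp-up in [0, 1], following Laine & Aila (2017).

    Assumption: the effective unlabelled weight would be lbda * rampup_weight(...).
    """
    if warmup_length <= 0:
        return 1.0  # no warmup: full weight from the start
    # Convert the global iteration into (fractional) PiModel-defined epochs.
    pimodel_epoch = current_iter / pimodel_n_batches
    # Fraction of the warmup completed, clipped to [0, 1].
    t = min(max(pimodel_epoch / warmup_length, 0.0), 1.0)
    return math.exp(-5.0 * (1.0 - t) ** 2)


if __name__ == "__main__":
    for it in (0, 250, 500, 1000, 2000):
        print(it, round(rampup_weight(it, pimodel_n_batches=100, warmup_length=10), 4))
```

With 100 batches per PiModel-defined epoch and a 10-epoch warmup, the weight starts at exp(-5) ≈ 0.0067 and reaches 1 at iteration 1000, after which it stays constant. Counting the warmup in PiModel-defined epochs, rather than in raw iterations, keeps the schedule independent of how labelled and unlabelled data are split into batches.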
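The debiased flag refers to the safe SSL estimator of Schmutz et al., which subtracts the unsupervised loss evaluated on the labelled samples from the one evaluated on the unlabelled samples. When labelled and unlabelled data come from the same distribution, this added term has expectation zero, so the objective remains an unbiased estimate of the supervised risk. The pure-Python sketch below illustrates that combination on per-sample values, using y == -1 to separate the two groups. It is an illustration under assumptions, not semipy's implementation: the names softmax, consistency and combine_pimodel_loss are hypothetical, the consistency term is taken to be the squared difference between the softmax outputs on the weak and strong views (as in the original Pi-Model), and the non-debiased variant is assumed to apply the consistency term to unlabelled samples only.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def consistency(weak_logits: Sequence[float], strong_logits: Sequence[float]) -> float:
    """Squared distance between predictions on the weak and strong views of one sample."""
    p, q = softmax(weak_logits), softmax(strong_logits)
    return sum((a - b) ** 2 for a, b in zip(p, q))


def _mean(values: List[float]) -> float:
    return sum(values) / len(values) if values else 0.0


def combine_pimodel_loss(sup_terms, cons_terms, y, lbda, weight, debiased=False):
    """Hypothetical combination of per-sample terms into one scalar loss.

    sup_terms:  per-sample supervised loss (only read where y != -1)
    cons_terms: per-sample consistency loss between weak and strong views
    y:          labels, with -1 marking unlabelled samples
    weight:     ramp-up factor in [0, 1] applied on top of lbda
    """
    labelled = [i for i, label in enumerate(y) if label != -1]
    unlabelled = [i for i, label in enumerate(y) if label == -1]
    sup = _mean([sup_terms[i] for i in labelled])
    unsup = _mean([cons_terms[i] for i in unlabelled])
    if debiased:
        # Safe SSL (Schmutz et al.): subtract the same unsupervised loss on labelled samples.
        unsup -= _mean([cons_terms[i] for i in labelled])
    return sup + lbda * weight * unsup
```

For example, with supervised terms [1.0, 2.0] on two labelled samples and consistency terms [0.5, 0.5, 1.0, 3.0] over a batch of two labelled and two unlabelled samples, lbda = 1 and weight = 1, the standard loss is 1.5 + 2.0 = 3.5 and the debiased loss is 1.5 + (2.0 - 0.5) = 3.0. Because of the subtraction, the debiased loss can become negative on some batches.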