semipy.methods.utils.adamatch_loss
Warning
This section is under construction.
semipy.methods.utils.adamatch_loss(model, weak_x, strong_x, y, lbda, threshold, dist_align, debiased=False)
Parameters
- model - The model to train.
- weak_x - A torch.Tensor representing one batch of weakly augmented data (both labelled and unlabelled).
- strong_x - A torch.Tensor representing the same batch as weak_x, but with strong augmentation.
- y - A torch.Tensor holding the labels of the current batch (unlabelled samples are labelled -1).
- lbda (float) - Balancing weight for the unlabelled loss.
- threshold (float) - Probability threshold for pseudo-labelling.
- dist_align - A semipy.methods.utils.DistAlign instance used to perform distribution alignment.
- debiased (bool) - Whether to activate safe-SSL (Schmutz et al.) via debiasing. Default: False
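The following is a minimal sketch of what an AdaMatch-style loss with these parameters might compute. It is not semipy's implementation, and several choices are assumptions: the function name adamatch_style_loss is hypothetical; distribution alignment is done inline with per-batch means (rescaling unlabelled probabilities by E[p_labelled] / E[p_unlabelled]) instead of through a DistAlign object, whose interface is not assumed here; threshold is treated as AdaMatch's relative confidence threshold (threshold times the mean top-1 confidence on labelled samples); and debiasing subtracts the same unsupervised term evaluated on labelled samples, in the spirit of Schmutz et al. AdaMatch's random logit interpolation and loss-weight warm-up are omitted.

```python
import torch
import torch.nn.functional as F


def adamatch_style_loss(model, weak_x, strong_x, y, lbda, threshold, debiased=False):
    """Hypothetical AdaMatch-style SSL loss sketch (not semipy's implementation).

    Unlabelled samples are marked with y == -1. The batch is assumed to contain
    at least one labelled and one unlabelled sample.
    """
    labelled = y != -1
    unlabelled = ~labelled

    logits_weak = model(weak_x)
    logits_strong = model(strong_x)

    # Supervised cross-entropy on the labelled, weakly augmented samples
    loss_s = F.cross_entropy(logits_weak[labelled], y[labelled])

    with torch.no_grad():
        probs_weak = logits_weak.softmax(dim=1)
        probs_l = probs_weak[labelled]
        probs_u = probs_weak[unlabelled]
        # Assumed batch-level distribution alignment: p_u * E[p_l] / E[p_u], renormalised
        probs_u = probs_u * probs_l.mean(0) / probs_u.mean(0)
        probs_u = probs_u / probs_u.sum(dim=1, keepdim=True)
        # Relative threshold: threshold times the mean top-1 confidence on labelled data
        rel_threshold = threshold * probs_l.max(dim=1).values.mean()
        conf_u, pseudo_u = probs_u.max(dim=1)
        mask_u = (conf_u >= rel_threshold).float()

    # Pseudo-label loss on the strongly augmented unlabelled samples
    ce_u = F.cross_entropy(logits_strong[unlabelled], pseudo_u, reduction="none")
    loss_u = (ce_u * mask_u).mean()

    if debiased:
        # Assumed debiasing: subtract the same unsupervised term computed on labelled data
        with torch.no_grad():
            conf_l, pseudo_l = probs_l.max(dim=1)
            mask_l = (conf_l >= rel_threshold).float()
        ce_l = F.cross_entropy(logits_strong[labelled], pseudo_l, reduction="none")
        loss_u = loss_u - (ce_l * mask_l).mean()

    return loss_s + lbda * loss_u
```

As a usage example, a mixed batch is encoded by putting -1 in y for every unlabelled sample, e.g. y = torch.tensor([0, 1, 2, 0, -1, -1, -1, -1]). With a very high threshold no pseudo-label passes the mask, so the sketch reduces to the supervised cross-entropy alone.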