
semipy.methods.utils.vat_loss

Warning

This section is under construction.

    semipy.methods.utils.vat_loss(model, weak_x, y, lbda, debiased=False)
Computes the labelled and unlabelled loss terms for the Virtual Adversarial Training (VAT) method.
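For reference, the textbook VAT objective (Miyato et al.) combined with a supervised cross-entropy term can be written as follows. Here λ plays the role of `lbda`, n_l and n_u are the numbers of labelled and unlabelled samples in the batch, p̂ is a fixed (no-gradient) copy of the model's predictive distribution p_θ, ξ and ε are the VAT step sizes, and d_i is a random unit vector. The debiased line assumes the safe-SSL correction of Schmutz et al. (the unlabelled term re-evaluated on labelled samples and subtracted). This is a sketch of the method, not a statement of how semipy normalizes or combines the terms internally.

```latex
\begin{aligned}
\mathcal{L}(\theta) &= \frac{1}{n_l}\sum_{i:\,y_i \neq -1} \mathrm{CE}\big(f_\theta(x_i),\, y_i\big)
  + \lambda \, \frac{1}{n_u}\sum_{i:\,y_i = -1} D_i(\theta), \\
D_i(\theta) &= \mathrm{KL}\big(\hat p(\cdot \mid x_i) \,\|\, p_\theta(\cdot \mid x_i + r_i)\big),
  \qquad r_i = \epsilon \, \frac{g_i}{\lVert g_i \rVert_2}, \\
g_i &= \nabla_r \, \mathrm{KL}\big(\hat p(\cdot \mid x_i) \,\|\, p_\theta(\cdot \mid x_i + r)\big)\Big|_{r = \xi d_i}, \\
\mathcal{L}_{\text{debiased}}(\theta) &= \mathcal{L}(\theta) - \lambda \, \frac{1}{n_l}\sum_{i:\,y_i \neq -1} D_i(\theta).
\end{aligned}
```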

Parameters

  • model - The model to train.
  • weak_x - A torch.Tensor containing one batch of weakly augmented data (labelled and unlabelled samples together).
  • y - A torch.Tensor containing the labels of the current batch; unlabelled samples have the label -1.
  • lbda (float) - Weight applied to the unlabelled loss.
  • debiased (bool) - Whether to enable safe SSL via debiasing (Schmutz et al.). Default: False
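A minimal PyTorch sketch of how a loss with these parameters can be computed, assuming `model` returns class logits and following the standard VAT recipe (power iteration to find the adversarial direction, then a KL penalty). The function name `vat_loss_sketch` and the extra hyperparameters `xi`, `eps` and `n_power` are hypothetical and not part of the documented signature; the `debiased` branch assumes the Schmutz et al. correction. This is not semipy's actual implementation, and it returns a single combined loss, which may differ from what `vat_loss` returns.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def _l2_normalize(d: torch.Tensor) -> torch.Tensor:
    # Scale each sample's perturbation to unit L2 norm.
    norms = d.reshape(d.size(0), -1).norm(dim=1)
    return d / (norms.view(-1, *([1] * (d.dim() - 1))) + 1e-8)


def vat_loss_sketch(model, weak_x, y, lbda, debiased=False, xi=0.1, eps=1.0, n_power=1):
    # Hypothetical sketch: xi, eps and n_power are assumed hyperparameters,
    # not arguments of semipy.methods.utils.vat_loss.
    labelled = y != -1
    unlabelled = ~labelled

    logits = model(weak_x)
    # Supervised term: cross-entropy on labelled samples only.
    sup_loss = F.cross_entropy(logits[labelled], y[labelled])

    # Clean predictions are treated as fixed targets (no gradient).
    log_p_clean = F.log_softmax(logits.detach(), dim=1)
    p_clean = log_p_clean.exp()

    # Power iteration: approximate the most sensitive input direction per sample.
    d = _l2_normalize(torch.randn_like(weak_x))
    for _ in range(n_power):
        d.requires_grad_(True)
        log_p_hat = F.log_softmax(model(weak_x + xi * d), dim=1)
        dist = F.kl_div(log_p_hat, p_clean, reduction="batchmean")
        (grad,) = torch.autograd.grad(dist, d)
        d = _l2_normalize(grad.detach())

    # Per-sample KL(p(.|x) || p(.|x + r_adv)); gradients flow through the model.
    log_p_adv = F.log_softmax(model(weak_x + eps * d), dim=1)
    kl = (p_clean * (log_p_clean - log_p_adv)).sum(dim=1)

    unsup_loss = kl[unlabelled].mean()
    if debiased:
        # Assumed debiasing (after Schmutz et al.): subtract the same term
        # evaluated on the labelled samples.
        unsup_loss = unsup_loss - kl[labelled].mean()

    return sup_loss + lbda * unsup_loss


if __name__ == "__main__":
    torch.manual_seed(0)
    model = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 3))
    x = torch.randn(8, 4)
    y = torch.tensor([0, 2, 1, -1, -1, -1, -1, -1])  # -1 marks unlabelled samples
    loss = vat_loss_sketch(model, x, y, lbda=1.0)
    loss.backward()
    print(float(loss))
```

With `lbda=0.0` the result reduces to the plain cross-entropy on the labelled samples. The sketch assumes every batch contains at least one labelled and one unlabelled sample; otherwise one of the means is taken over an empty tensor and yields NaN.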