MedMNIST Tutorial
In this tutorial, we will run a simple FixMatch training on 'pathmnist', a 9-class dataset from MedMNIST with 89,996 training samples. Only 6,000 of them will be used as labelled samples, and the rest will be unlabelled. The dataset also comes with a validation set of 10,004 images and a test set of 7,180 images.
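This split leaves 89,996 - 6,000 = 83,996 unlabelled training images. The sketch below shows one way to build such a split. It assumes a uniform random selection, and `split_labelled_unlabelled` is a hypothetical helper. SemiPy may choose its labelled samples differently, for example with a class-balanced selection.

```python
import random


def split_labelled_unlabelled(num_samples, num_labelled, seed=0):
    """Hypothetical helper: uniformly pick `num_labelled` labelled indices.

    All remaining indices are treated as unlabelled. This is not
    necessarily the selection strategy SemiPy uses internally.
    """
    rng = random.Random(seed)
    indices = list(range(num_samples))
    rng.shuffle(indices)
    labelled = sorted(indices[:num_labelled])
    unlabelled = sorted(indices[num_labelled:])
    return labelled, unlabelled


labelled, unlabelled = split_labelled_unlabelled(89_996, 6_000)
print(len(labelled), len(unlabelled))
```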
import semipy as smp
import torch
To simplify things, we will use a configuration file, which you can find here. It lets us define the parameters of our training in one place.
args = smp.tools.get_config('config.yaml')
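`get_config` returns an `args` object that is later passed to the trainer. This tutorial does not show how it works internally. A loader like this commonly parses the YAML into a dictionary (for example with PyYAML's `yaml.safe_load`) and then exposes the keys as attributes. The sketch below covers only that second step. The `to_namespace` helper and the keys are hypothetical, because the contents of `config.yaml` are not shown.

```python
from types import SimpleNamespace


def to_namespace(obj):
    """Recursively turn nested dicts (and lists of dicts) into namespaces."""
    if isinstance(obj, dict):
        return SimpleNamespace(**{k: to_namespace(v) for k, v in obj.items()})
    if isinstance(obj, list):
        return [to_namespace(v) for v in obj]
    return obj


# Hypothetical keys: the real config.yaml may be organised differently.
raw = {
    "dataset": {"name": "pathmnist", "num_labelled": 6000},
    "training": {"epochs": 25, "batch_size": 64},
}
args = to_namespace(raw)
print(args.dataset.name, args.training.epochs)
```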
After reading our config file with 'get_config', we retrieve our datasets. Note that all the steps in this tutorial can be done automatically with the "SSLTrainer" class defined in SemiPy, but since this is a tutorial, we detail each of them. This is why, in the next cell, we pass the number of labelled samples to 'get_medmnist' manually, even though it is also present in the configuration file. Because no validation proportion is passed, 'get_medmnist' warns that the validation set will be the original MedMNIST validation set. Also note that 'augmentation' is set to True, because FixMatch training requires strong augmentation.
sets = smp.datasets.get_medmnist(name='pathmnist', num_labelled=6000, augmentation=True, include_labelled=True)
C:\Users\lboiteau\Documents\Demos\semipy\datasets\medmnist.py:44: UserWarning: Warning: valid_proportion is set to 0 or not defined. Length of validation set will be the length of the original set from MedMNIST
  warnings.warn('Warning: valid_proportion is set to 0 or not defined. '
Using downloaded and verified file: C:\Users\lboiteau\.medmnist\pathmnist.npz
Using downloaded and verified file: C:\Users\lboiteau\.medmnist\pathmnist.npz
Using downloaded and verified file: C:\Users\lboiteau\.medmnist\pathmnist.npz
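FixMatch uses two views of each unlabelled image. A weakly augmented view (flip and shift) produces a pseudo-label, and a strongly augmented view is trained against that pseudo-label. The exact transforms SemiPy applies when `augmentation=True` are not shown here, so the NumPy sketch below only illustrates the idea. `weak_augment` and `strong_augment` are hypothetical helpers. For brevity, the strong view uses Cutout alone in place of the RandAugment/CTAugment policy used in the FixMatch paper.

```python
import numpy as np


def weak_augment(img, rng, pad=4):
    """Weak view: random horizontal flip plus a random translation.

    `img` is an (H, W, C) array. The translation reflect-pads the image
    and then crops it back to its original size.
    """
    if rng.random() < 0.5:
        img = img[:, ::-1]  # flip along the width axis
    h, w = img.shape[:2]
    padded = np.pad(img, ((pad, pad), (pad, pad), (0, 0)), mode="reflect")
    top = rng.integers(0, 2 * pad + 1)
    left = rng.integers(0, 2 * pad + 1)
    return padded[top:top + h, left:left + w]


def strong_augment(img, rng, cutout=8):
    """Strong view: weak augmentation followed by Cutout (a zeroed square).

    Cutout stands in here for the full RandAugment/CTAugment policy.
    """
    img = weak_augment(img, rng).copy()
    h, w = img.shape[:2]
    y, x = rng.integers(0, h), rng.integers(0, w)
    y0, y1 = max(0, y - cutout // 2), min(h, y + cutout // 2)
    x0, x1 = max(0, x - cutout // 2), min(w, x + cutout // 2)
    img[y0:y1, x0:x1] = 0  # mask a square patch around (y, x)
    return img


rng = np.random.default_rng(0)
image = rng.integers(0, 256, size=(28, 28, 3), dtype=np.uint8)  # pathmnist-sized RGB
weak, strong = weak_augment(image, rng), strong_augment(image, rng)
print(weak.shape, strong.shape)
```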
Next, let's create the model to train. We will use a simple ResNet-18 from PyTorch, with 9 output classes to match pathmnist.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = smp.models.get_model('resnet18', num_classes=9)
model = model.to(device)
We choose an SGD optimizer with a learning rate of 0.03.
optimizer = torch.optim.SGD(model.parameters(), lr=0.03)
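This tutorial passes `scheduler=None` to the trainer, so the learning rate stays at 0.03. For comparison, the FixMatch paper uses SGD with Nesterov momentum 0.9 and decays the learning rate as η·cos(7πk / (16K)), where k is the current step and K is the total number of steps. The sketch below computes that factor. `fixmatch_cosine_factor` is a hypothetical name, and `total_steps` assumes 25 epochs of 188 iterations, matching the training output of this tutorial.

```python
import math


def fixmatch_cosine_factor(step, total_steps):
    """Multiplicative LR factor cos(7*pi*k / (16*K)) from the FixMatch paper."""
    return math.cos(7 * math.pi * step / (16 * total_steps))


base_lr = 0.03
total_steps = 25 * 188  # 25 epochs x 188 iterations per epoch
for k in (0, total_steps // 2, total_steps):
    print(k, base_lr * fixmatch_cosine_factor(k, total_steps))
```

In PyTorch, such a factor can be wrapped in `torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda k: fixmatch_cosine_factor(k, total_steps))`. Whether SemiPy's trainer steps its scheduler once per iteration or once per epoch is not shown here, so K would need to be chosen accordingly.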
Next comes a very important part: creating the dataloader. In standard supervised training, a plain DataLoader would be enough. Here, we need to separate labelled and unlabelled samples, which is why we use SemiPy's JointSampler. It lets us work with a single dataset containing both labelled and unlabelled items. Simply choose a batch size and the proportion of labelled items you want in each batch, and you are good to go!
sampler = smp.sampler.JointSampler(dataset=sets['Train'], batch_size=64, proportion=0.5)
dataloader = torch.utils.data.DataLoader(sets['Train'], batch_sampler=sampler)
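With `batch_size=64` and `proportion=0.5`, each batch holds 32 labelled and 32 unlabelled samples. If one epoch is one pass over the 6,000 labelled samples, that gives ⌈6000 / 32⌉ = 188 iterations per epoch, which matches the 188 iterations per epoch reported during training. The sketch below shows a batch sampler of this kind. `JointBatchSampler` is a hypothetical class, not SemiPy's `JointSampler`. Its epoch definition is an assumption, and it only yields indices: how the dataset marks an item as unlabelled is not shown in this tutorial.

```python
import math
import random


class JointBatchSampler:
    """Hypothetical batch sampler mixing labelled and unlabelled indices.

    Each batch holds round(batch_size * proportion) labelled indices, and
    the rest are drawn from the unlabelled pool. One epoch is one pass over
    the labelled indices (an assumption, not SemiPy's documented behaviour).
    """

    def __init__(self, labelled_idx, unlabelled_idx, batch_size=64, proportion=0.5, seed=0):
        self.labelled_idx = list(labelled_idx)
        self.unlabelled_idx = list(unlabelled_idx)
        self.num_labelled = round(batch_size * proportion)
        self.num_unlabelled = batch_size - self.num_labelled
        self.rng = random.Random(seed)

    def __len__(self):
        return math.ceil(len(self.labelled_idx) / self.num_labelled)

    def __iter__(self):
        labelled = self.labelled_idx[:]
        self.rng.shuffle(labelled)
        for i in range(len(self)):
            lab = labelled[i * self.num_labelled:(i + 1) * self.num_labelled]
            unl = self.rng.sample(self.unlabelled_idx, self.num_unlabelled)
            yield lab + unl  # labelled indices first, then unlabelled ones


# Assume indices 0..5999 are labelled and the remaining ones are unlabelled.
sampler = JointBatchSampler(range(6000), range(6000, 89_996))
print(len(sampler))
```

Because it yields lists of indices, an object like this can be passed as `batch_sampler=` to `torch.utils.data.DataLoader`, in the same way as `JointSampler` above.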
val_dataloader = torch.utils.data.DataLoader(sets['Validation'], batch_size=64, shuffle=False)
Finally, it's time to choose a semi-supervised learning (SSL) method. We will use FixMatch and, since we are not using PyTorch Lightning in this tutorial, the simple trainer included in SemiPy:
trainer = smp.methods.FixMatch(args, model, dataloader, val_dataloader, optimizer, scheduler=None, num_classes=9)
trainer.training()
[Progress bars: 25 epochs, 188 iterations per epoch]
[1.6222473915587081, 1.1791775106115545, 1.0396045560532428, 0.9519544747915674, 0.8916148728829749, 0.8484063008998303, 0.8072013575980004, 0.7802366132431842, 0.7285392368410496, 0.7350093584428442, 0.6964399213803575, 0.6474804099886975, 0.6309603810944455, 0.6154432499662359, 0.6662413840915294, 0.5623518959321874, 0.5926185534038442, 0.5563194929285252, 0.5360793276353085, 0.5329677575921759, 0.5152371287187363, 0.5018856583282034, 0.47833075889564575, 0.47020974707730273, 0.44847835243699397]
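`trainer.training()` returns 25 values, presumably the average training loss of each epoch. In FixMatch, the loss combines two terms. The first is a supervised cross-entropy on the labelled part of the batch. The second is an unsupervised term: when the prediction on the weakly augmented view of an unlabelled image is confident enough (threshold τ = 0.95 in the FixMatch paper), its argmax becomes a hard pseudo-label, and the prediction on the strongly augmented view is trained against it. The NumPy sketch below computes this objective with λ_u = 1, as in the paper. The function names are hypothetical, and the exact quantity SemiPy logs may differ.

```python
import numpy as np


def log_softmax(logits):
    z = logits - logits.max(axis=1, keepdims=True)  # shift for numerical stability
    return z - np.log(np.exp(z).sum(axis=1, keepdims=True))


def cross_entropy(logits, targets):
    """Per-sample cross-entropy between logits and integer class targets."""
    return -log_softmax(logits)[np.arange(len(targets)), targets]


def fixmatch_loss(logits_lab, labels, logits_weak, logits_strong, threshold=0.95, lambda_u=1.0):
    # Supervised term on the labelled part of the batch.
    loss_s = cross_entropy(logits_lab, labels).mean()
    # Pseudo-labels from the weak view (treated as constants in training).
    probs_weak = np.exp(log_softmax(logits_weak))
    confidence = probs_weak.max(axis=1)
    pseudo_labels = probs_weak.argmax(axis=1)
    mask = (confidence >= threshold).astype(float)
    # Unsupervised term: strong view vs. confident pseudo-labels only,
    # averaged over all unlabelled samples as in the FixMatch paper.
    loss_u = (cross_entropy(logits_strong, pseudo_labels) * mask).mean()
    return loss_s + lambda_u * loss_u


rng = np.random.default_rng(0)
num_classes = 9
logits_lab = rng.normal(size=(32, num_classes))
labels = rng.integers(0, num_classes, size=32)
logits_weak = rng.normal(size=(32, num_classes)) * 5  # sharper, so some pass τ
logits_strong = rng.normal(size=(32, num_classes))
print(fixmatch_loss(logits_lab, labels, logits_weak, logits_strong))
```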
trainer.eval(0, 0)
{'VALIDATION/MulticlassAccuracy': tensor(0.2592), 'VALIDATION/Loss': tensor(1.9023)}