# TP3: Image Segmentation With a U-Net - From CPU to GPU

The goal of this notebook is to introduce the notion of **hardware acceleration** in the context of deep learning. We will use the application of image segmentation with a U-Net to illustrate the difference in performance between a CPU and a GPU.

**Image segmentation** is an important task in computer vision. The goal is to find multiple areas in an image and to assign labels to these areas. It provides a different kind of information than:
- **image classification**, which characterizes images with global labels;
- **object detection**, which usually relies on finding bounding boxes around detected objects.

Segmentation is used in many real-world applications such as medical imaging, clothes segmentation, flood mapping, self-driving cars, etc. There are two main types of image segmentation:
- Semantic segmentation: classify each pixel with a label.
- Instance segmentation: classify each pixel and differentiate each object instance.

U-Net is a semantic segmentation technique [originally proposed for medical image segmentation](https://arxiv.org/abs/1505.04597). It is one of the earliest deep learning segmentation models, and the architecture is still widely used as a building block of more advanced models such as Generative Adversarial Networks or diffusion models.

The model architecture is fairly simple: an encoder (for downsampling) and a decoder (for upsampling) linked by skip connections. U-Net is fully convolutional: even the output classification is done at the pixel level with a *(1,1)* convolution. It therefore has the following advantages:
- parameter and data efficiency;
- independence from the input size (in the implementation below, with four max-pooling steps, the height and width must still be divisible by $2^4 = 16$).

The following image is taken from the original paper:

<img src="https://lmb.informatik.uni-freiburg.de/people/ronneber/u-net/u-net-architecture.png"  width="60%" height="30%">


In this session, we will consider a binary segmentation task.  

# 1. Cell nuclei segmentation: the dataset

Cell nuclei segmentation is an essential step in the biological analysis of microscopy images.
It can be done manually with dedicated software, but this is very costly.
In this lab session, the starting point is this [*Scientific Data* (Nature) paper](https://www.nature.com/articles/s41597-020-00608-w). To quote part of the paper:

> Fully-automated nuclear image segmentation is the prerequisite to ensure statistically significant, quantitative analyses of tissue preparations, applied in digital pathology or quantitative microscopy. The design of segmentation methods that work independently of the tissue type or preparation is complex, due to variations in nuclear morphology, staining intensity, cell density and nuclei aggregations. Machine learning-based segmentation methods can overcome these challenges, however high quality expert-annotated images are required for training. Currently, the limited number of annotated fluorescence image datasets publicly available do not cover a broad range of tissues and preparations. We present a comprehensive, annotated dataset including tightly aggregated nuclei of multiple tissues for the training of machine learning-based nuclear segmentation algorithms. The proposed dataset covers sample preparation methods frequently used in quantitative immunofluorescence microscopy.

To save preprocessing time, this lab session starts from a preprocessed pickle file.

In [1]:
import numpy as np 
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import pickle
import time
from tqdm.notebook import trange, tqdm

In [None]:
!wget https://www.lamsade.dauphine.fr/~averine/DL3AIISO/nuclei_cells_segmentations.pck

In [None]:
fn = "nuclei_cells_segmentations.pck"
with open(fn, 'rb') as f:
    X, Y = pickle.load(f)
print(f"Nuclei images: {X.shape}, Cell images: {Y.shape}")
N = X.shape[0]

Let us visualize some images from the dataset.

In [None]:
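plt.figure(figsize=(15,3))
for i in range(0, 10):
    # top row: X; ticks are hidden instead of using axis('off'),
    # which would also hide the row label
    ax = plt.subplot(2,10,i+1)
    plt.imshow(X[i].squeeze())
    ax.set_xticks([])
    ax.set_yticks([])
    if i == 0:
        ax.set_ylabel("Nuclei")
    # bottom row: Y
    ax = plt.subplot(2,10,i+11)
    plt.imshow(Y[i].squeeze())
    ax.set_xticks([])
    ax.set_yticks([])
    if i == 0:
        ax.set_ylabel("Cells")
plt.show()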

This pickle contains a modified version of the dataset:
- the same number of images;
- all the images are resized to 128×128 pixels;
- the segmentation task is converted into a binary pixel classification: nucleus or not.

The goal is now to train a U-Net on this dataset (69 images for training and 10 for "test"). 

In [5]:
M = 69
trainset = X[:M], Y[:M]
testset = X[M:], Y[M:]
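The training function reports a pixel accuracy. Since we have not checked how much of each image is covered by nuclei, it is worth comparing that accuracy with a trivial baseline that always predicts the most frequent class. A minimal sketch in plain Python; `majority_baseline` is a hypothetical helper, not part of the original notebook:

```python
def majority_baseline(labels):
    """Pixel accuracy (%) of a trivial model that always predicts the most frequent class.

    labels: flat iterable of 0/1 pixel labels (1 = nucleus).
    """
    labels = list(labels)
    frac_positive = sum(1 for v in labels if v == 1) / len(labels)
    return 100 * max(frac_positive, 1 - frac_positive)


print(majority_baseline([0, 0, 0, 1]))  # 75.0
```

On the actual tensors, the same number can be obtained with `p = (Y == 1).float().mean().item()` followed by `100 * max(p, 1 - p)`; a trained model is only useful if its pixel accuracy is clearly above this value.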

# 2. U-Net Architecture

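Following the previous picture of U-Net, the network is composed of three parts: an encoder, a bottleneck and a decoder.
All three rely on the same convolutional block: two successive 3×3 convolutions, each followed by a ReLU (the implementation below also adds batch normalization after each convolution).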

The first part is the **encoder**. Its goal is to compress the "geometrical" information into local features. The encoder first applies a convolutional block with kernels of size (3,3) to extract $F=64$ features. Then the information is compressed using max-pooling (factor 2), which halves the height and width. The next step does the same: it extracts $2\times F=128$ features from the $F=64$ ones, then compresses them with max-pooling. This operation is repeated 4 times in total, so that we end up with $F\times 8 = 512$ channels that represent global features extracted from the input image.
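To make these shapes concrete, here is a small sketch in plain Python that traces the encoder for a 128×128 single-channel input. `conv_out` applies the usual output-size formula of `nn.Conv2d` and `nn.MaxPool2d` (dilation 1, floor rounding); both helper names are ours, chosen for illustration, and the block structure assumed is the one implemented further down.

```python
def conv_out(size, kernel_size=3, stride=1, padding=1):
    """Spatial output size of nn.Conv2d or nn.MaxPool2d (dilation 1, floor rounding)."""
    return (size + 2 * padding - kernel_size) // stride + 1


def encoder_shapes(size=128, F=64, depth=4):
    """Trace (channels, height, width) through the U-Net encoder.

    Each level applies two 3x3 convolutions with padding 1 (size unchanged),
    keeps the result for the skip connection, then max-pools with a 2x2 window.
    Returns the shapes of the skip tensors and the bottleneck output shape.
    """
    skips = []
    channels = F
    for _ in range(depth):
        size = conv_out(conv_out(size))  # two 'same' 3x3 convolutions
        skips.append((channels, size, size))
        size = conv_out(size, kernel_size=2, stride=2, padding=0)  # max-pooling halves H and W
        channels *= 2
    # after the loop, channels == F * 2**depth: the bottleneck doubles 8F into 16F
    return skips, (channels, size, size)


skips, bottleneck = encoder_shapes(128, F=64)
print(skips)       # [(64, 128, 128), (128, 64, 64), (256, 32, 32), (512, 16, 16)]
print(bottleneck)  # (1024, 8, 8)
```

With a 100×100 input, the skip tensors would be 100, 50, 25 and 12 pixels wide and the bottleneck 6×6; the decoder would then produce a 24×24 map that cannot be concatenated with the 25×25 skip tensor. This is why, with this implementation, the height and width should be multiples of $2^4 = 16$.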

The **bottleneck** is a convolutional block which doubles the number of channels ($8F \to 16F$). The idea is to create a "dense" representation of the image that gathers both global and local features.

The **decoder** mirrors the encoder. Where the encoder uses max-pooling for downsampling, the decoder upsamples with **transposed convolutions** (kernel size 2, stride 2), which double the height and width of the intermediate feature maps. At each level, the upsampled maps are concatenated with the encoder output of the same resolution (the skip connection) before a convolutional block is applied.
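The doubling of the spatial size follows from the output-size formula given in the PyTorch documentation of `nn.ConvTranspose2d`. A plain-Python sketch (helper names are ours) that applies it to the four decoder levels, starting from the 8×8 bottleneck obtained for a 128×128 input with $F=64$:

```python
def conv_transpose_out(size, kernel_size=2, stride=2, padding=0, output_padding=0, dilation=1):
    """Spatial output size of nn.ConvTranspose2d (formula from the PyTorch docs)."""
    return (size - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1


def decoder_shapes(F=64, bottleneck_size=8, depth=4):
    """Trace (channels, height, width) at the output of each decoder block.

    Each block: ConvTranspose2d(in_c, in_c // 2, kernel_size=2, stride=2),
    concatenation with the skip tensor (in_c // 2 more channels),
    then a conv block that brings the channels back to in_c // 2.
    """
    channels, size = F * 2**depth, bottleneck_size
    shapes = []
    for _ in range(depth):
        channels //= 2
        size = conv_transpose_out(size)  # kernel 2, stride 2: H and W are doubled
        shapes.append((channels, size, size))
    return shapes


print(decoder_shapes())  # [(512, 16, 16), (256, 32, 32), (128, 64, 64), (64, 128, 128)]
```

The last decoder output has the same height and width as the input, with $F$ channels, which is what the final classification layer expects.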

The last peculiarity is the output layer, which performs the classification at the pixel level. In U-Net, this layer is (once again) a convolution, with a *(1,1)* kernel. The last hidden layer has the same spatial dimensions as the input, with $F$ feature maps. The classification is carried out for each pixel independently, but the decision is based on $F$ features that encode global information.
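A *(1,1)* convolution with a single output channel is simply a per-pixel weighted sum of the $F$ feature values plus a bias. The sketch below computes it explicitly in plain Python on toy values chosen for illustration. Because training applies `BCEWithLogitsLoss` to these outputs, they are logits: a logit above 0 corresponds to a sigmoid probability above 0.5, which is why the notebook thresholds the outputs at 0.

```python
import math


def conv1x1(features, weights, bias):
    """Single-output 1x1 convolution written explicitly.

    features: list of F feature maps, each an H x W list of lists
    weights:  list of F scalars (one per input channel)
    Returns the H x W map of logits: bias + sum_c weights[c] * features[c][i][j].
    """
    H, W = len(features[0]), len(features[0][0])
    return [[bias + sum(w * f[i][j] for w, f in zip(weights, features))
             for j in range(W)]
            for i in range(H)]


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


# toy example: F = 2 feature maps of size 2 x 2
features = [[[1.0, 2.0], [3.0, 4.0]],
            [[0.0, 1.0], [5.0, 1.0]]]
logits = conv1x1(features, weights=[1.0, -1.0], bias=0.5)
print(logits)  # [[1.5, 1.5], [-1.5, 3.5]]
mask = [[z > 0 for z in row] for row in logits]
print(mask)    # [[True, True], [False, True]]
```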



## 2.1 U-Net: step by step

Now the goal is to implement U-Net. We propose the following roadmap:
- a function to create a convolutional block
- a module for the encoder
- a module for the decoder
- and a U-Net module to wrap everything

The number of feature maps ($F=64$ in the original work) must be a parameter of the U-Net. For the first round of experiments, we can use $F=8$.

In [6]:
def make_conv_block(in_c, out_c,
                    kernel_size=3,
                    stride=1,
                    padding=1):
    """ Convolutional block, the basic building block of U-Net:
        two successive (3x3 convolution -> BatchNorm -> ReLU) layers.
        With kernel_size=3, stride=1 and padding=1, the height and width are preserved.
        Args:
        - in_c: number of input channels
        - out_c: number of output channels
        - kernel_size: size of the kernel
        - stride: stride of the convolution
        - padding: padding of the convolution
        Output:
        - a Sequential module
    """
    mod = nn.Sequential(
                nn.Conv2d(in_channels=in_c,
                          out_channels=out_c,
                          kernel_size=kernel_size,
                          stride=stride,
                          padding=padding),
                nn.BatchNorm2d(num_features=out_c),
                nn.ReLU(),
                nn.Conv2d(in_channels=out_c,
                          out_channels=out_c,
                          kernel_size=kernel_size,
                          stride=stride,
                          padding=padding),
                nn.BatchNorm2d(num_features=out_c),
                nn.ReLU()
                )
    return mod

class encoder_block(nn.Module):
    """ Encoder block:
        - a conv block
        - followed by a MaxPooling: the height and width is divided by 2 
        We need both the output before and after the max-pooling. 
        - the output before the maxpooling is used for the "residual" connection
        - the output after is given to the next encoder block or the bottleneck block. 
       
        Remark: we could use a Sequential module but for the forward we need 
        to get x and p
    """
    def __init__(self, in_c, out_c):
        super().__init__()
        self.conv = make_conv_block(in_c, out_c)
        self.pool = nn.MaxPool2d(kernel_size=(2, 2))

    def forward(self, inputs):
        x = self.conv(inputs)
        p = self.pool(x)
        return x, p


class decoder_block(nn.Module):
    """ Decoder block:
        - First upsampling is done with a ConvTranspose2D
        - The "residual" connection is implemented with the cat function
        - Then a conv_block is applied. 
        The ConvTranspose doubles the height and width.
        The forward call takes two args: the inputs from the previous block and the residual input. 
    """
    def __init__(self, in_c, out_c):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_c, out_c, kernel_size=2, stride=2, padding=0)
        self.conv = make_conv_block(out_c+out_c, out_c)

    def forward(self, inputs, skip):
        x = self.up(inputs)
        x = torch.cat([x, skip], axis=1)
        x = self.conv(x)
        return x



class UNet(nn.Module):
    def __init__(self,factor=64):
        super().__init__()
        self.f = factor
        """ Encoder """
        self.e1 = encoder_block(1, self.f)
        self.e2 = encoder_block(self.f, self.f*2)
        self.e3 = encoder_block(self.f*2, self.f*4)
        self.e4 = encoder_block(self.f*4, self.f*8)

        """ Bottleneck """
        self.b = make_conv_block(self.f*8, self.f*16)

        """ Decoder """
        self.d1 = decoder_block(self.f*16, self.f*8)
        self.d2 = decoder_block(self.f*8, self.f*4)
        self.d3 = decoder_block(self.f*4, self.f*2)
        self.d4 = decoder_block(self.f*2, self.f)

        """ Classifier """
        self.outputs = nn.Conv2d(self.f, 1, kernel_size=1, padding=0)

    def forward(self, inputs):
        """ Encoder """
        s1, p1 = self.e1(inputs)
        s2, p2 = self.e2(p1)
        s3, p3 = self.e3(p2)
        s4, p4 = self.e4(p3)

        """ Bottleneck """
        b = self.b(p4)

        """ Decoder """
        d1 = self.d1(b, s4)
        d2 = self.d2(d1, s3)
        d3 = self.d3(d2, s2)
        d4 = self.d4(d3, s1)

        """ Classifier """
        outputs = self.outputs(d4)
        return outputs

    
    
class LightUNet(nn.Module):
    def __init__(self,factor=8):
        super().__init__()
        self.f = factor
        """ Encoder """
        self.e1 = encoder_block(1, self.f)
        
        """ Bottleneck """
        self.b = make_conv_block(self.f, self.f*2)

        """ Decoder """
        self.d1 = decoder_block(self.f*2, self.f)

        """ Classifier """
        self.outputs = nn.Conv2d(self.f, 1, kernel_size=1, padding=0)

    def forward(self, inputs):
        """ Encoder """
        s1, p1 = self.e1(inputs)
        
        """ Bottleneck """
        b = self.b(p1)

        """ Decoder """
        d1 = self.d1(b, s1)
        
        """ Classifier """
        outputs = self.outputs(d1)
        return outputs


We can now define two U-Nets of different sizes: the full `UNet` (four encoder/decoder levels, here with $F=64$) and the `LightUNet` (a single level, with $F=8$). Let us compare the speed of inference of the two models.
By default, a model is created on the CPU. We will then compare the speed of inference on CPU and on GPU.
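The parameter counts printed by the timing cell can be checked by hand. A `Conv2d` or `ConvTranspose2d` layer has $c_{in} \cdot c_{out} \cdot k^2$ weights plus $c_{out}$ biases, and each `BatchNorm2d` has $2 c_{out}$ learnable parameters (its running statistics are buffers, not parameters). A sketch in plain Python, assuming the block structure defined above; the function names are ours:

```python
def conv_params(in_c, out_c, k):
    """Weights + biases of a Conv2d or ConvTranspose2d layer with a k x k kernel."""
    return in_c * out_c * k * k + out_c


def conv_block_params(in_c, out_c):
    """make_conv_block: two 3x3 convolutions, each with a BatchNorm2d (weight + bias)."""
    return conv_params(in_c, out_c, 3) + 2 * out_c + conv_params(out_c, out_c, 3) + 2 * out_c


def decoder_block_params(in_c, out_c):
    """decoder_block: 2x2 transposed convolution, then a conv block on 2 * out_c channels."""
    return conv_params(in_c, out_c, 2) + conv_block_params(2 * out_c, out_c)


def light_unet_params(F):
    return (conv_block_params(1, F)             # encoder
            + conv_block_params(F, 2 * F)       # bottleneck
            + decoder_block_params(2 * F, F)    # decoder
            + conv_params(F, 1, 1))             # 1x1 classifier


def unet_params(F):
    widths = [F, 2 * F, 4 * F, 8 * F]
    encoder = conv_block_params(1, F) + sum(conv_block_params(a, b) for a, b in zip(widths, widths[1:]))
    bottleneck = conv_block_params(8 * F, 16 * F)
    decoder = sum(decoder_block_params(2 * c, c) for c in reversed(widths))
    return encoder + bottleneck + decoder + conv_params(F, 1, 1)


print(light_unet_params(8))  # 6553
print(unet_params(64))       # 31042369
```

The counts grow roughly with $F^2$, so the full U-Net with $F=64$ has several thousand times more parameters than the LightUNet with $F=8$.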


First, let us check if your computer has a GPU.

In [None]:
print(f"GPU available: {torch.cuda.is_available()}")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")
print(f"Number of GPUs: {torch.cuda.device_count()}")


<font color='blue'>TODO:</font> 
- Run the following cell to compare the speed of inference on CPU and GPU.
- Comment the results.
- Consider running your code on Colab if your own computer does not have a GPU.

In [None]:
model_cpu = UNet(factor=64)
print(f"UNet: {sum(p.numel() for p in model_cpu.parameters())} parameters")
light_model_cpu = LightUNet(factor=8)
print(f"LightUNet: {sum(p.numel() for p in light_model_cpu.parameters())} parameters")

def timed_forward(model, x):
    """ Wall-clock time (s) of one forward pass.
        CUDA kernels run asynchronously: we synchronize before and after,
        so that the measured time includes the actual GPU computation. """
    if x.is_cuda:
        torch.cuda.synchronize()
    t0 = time.perf_counter()
    model(x)
    if x.is_cuda:
        torch.cuda.synchronize()
    return time.perf_counter() - t0

times_model_cpu = []
times_light_model_cpu = []
model_cpu.eval()
light_model_cpu.eval()
x_cpu = trainset[0]
use_gpu = torch.cuda.is_available()
if use_gpu:
    times_model_gpu = []
    times_light_model_gpu = []
    model_gpu = UNet(factor=64).cuda().eval()
    light_model_gpu = LightUNet(factor=8).cuda().eval()
    x_gpu = x_cpu.cuda()  # host-to-device copy done once, outside the timed region
    with torch.no_grad():  # warm-up: the first CUDA calls pay one-off initialization costs
        model_gpu(x_gpu)
        light_model_gpu(x_gpu)

with torch.no_grad():  # inference only: no computational graph is needed
    for _ in trange(10):
        times_model_cpu.append(timed_forward(model_cpu, x_cpu))
        times_light_model_cpu.append(timed_forward(light_model_cpu, x_cpu))
        if use_gpu:
            times_model_gpu.append(timed_forward(model_gpu, x_gpu))
            times_light_model_gpu.append(timed_forward(light_model_gpu, x_gpu))
print(f"UNet CPU: {np.mean(times_model_cpu):.3f} s")
if use_gpu:
    print(f"UNet GPU: {np.mean(times_model_gpu):.3f} s")
    print(f"Inference speedup: {np.mean(times_model_cpu)/np.mean(times_model_gpu):.1f}")
print(f"LightUNet CPU: {np.mean(times_light_model_cpu):.3f} s")
if use_gpu:
    print(f"LightUNet GPU: {np.mean(times_light_model_gpu):.3f} s")
    print(f"Inference speedup: {np.mean(times_light_model_cpu)/np.mean(times_light_model_gpu):.1f}")
    del model_gpu, light_model_gpu, x_gpu
del model_cpu, light_model_cpu

We can also compare how the inference time on CPU and on GPU evolves with the batch size.

In [None]:
if torch.cuda.is_available():
    time_inference_cpu_mean = []
    time_inference_gpu_mean = []
    time_inference_cpu_std = []
    time_inference_gpu_std = []
    light_model_cpu = LightUNet(factor=8)
    light_model_cpu.eval()
    light_model_gpu = LightUNet(factor=8).cuda()
    light_model_gpu.eval()
    with torch.no_grad():
        for k in trange(0, 7):
            time_inference_cpu = []
            time_inference_gpu = []
            batch_size = 2**k
            x_cpu = trainset[0][:batch_size]
            x_gpu = x_cpu.cuda()  # host-to-device copy kept outside the timed region
            for _ in range(20):
                t0 = time.perf_counter()
                light_model_cpu(x_cpu)
                time_inference_cpu.append(time.perf_counter() - t0)
            time_inference_cpu_mean.append(np.mean(time_inference_cpu))
            time_inference_cpu_std.append(np.std(time_inference_cpu))
            light_model_gpu(x_gpu)  # warm-up call for this batch size
            for _ in range(20):
                torch.cuda.synchronize()
                t0 = time.perf_counter()
                light_model_gpu(x_gpu)
                torch.cuda.synchronize()  # wait until the asynchronous GPU kernels have finished
                time_inference_gpu.append(time.perf_counter() - t0)
            time_inference_gpu_mean.append(np.mean(time_inference_gpu))
            time_inference_gpu_std.append(np.std(time_inference_gpu))
    del light_model_cpu, light_model_gpu

    x = [2**i for i in range(7)]    
    plt.figure(figsize=(15,5))
    plt.subplot(1,2,1)
    plt.plot(x, time_inference_cpu_mean, label="CPU")
    plt.fill_between(x,np.array(time_inference_cpu_mean)-np.array(time_inference_cpu_std),
                        np.array(time_inference_cpu_mean)+np.array(time_inference_cpu_std),
                        alpha=0.3)
    affine_approx = np.polyfit(x[3:], time_inference_cpu_mean[3:], 1) 
    plt.plot(x, np.polyval(affine_approx, x), linestyle="--", color="black", label=f"Affine approx {affine_approx[0]:.3e}x")


    plt.xlabel("Batch size")
    plt.ylabel("Inference time (s)")
    plt.legend()
    plt.subplot(1,2,2)
    plt.plot(x, time_inference_gpu_mean, label="GPU")
    plt.fill_between(x,np.array(time_inference_gpu_mean)-np.array(time_inference_gpu_std),
                        np.array(time_inference_gpu_mean)+np.array(time_inference_gpu_std),
                        alpha=0.3)
    affine_approx = np.polyfit(x[3:], time_inference_gpu_mean[3:], 1)
    plt.plot(x, np.polyval(affine_approx, x), linestyle="--", color="black", label=f"Affine approx {affine_approx[0]:.3e}x")
    plt.xlabel("Batch size")
    plt.ylabel("Inference time (s)")
    plt.title("Inference time for LightUNet")
    plt.legend() 
    plt.show()

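The dashed lines in these plots are affine fits $t(b) \approx a\,b + c$ computed on the larger batch sizes (`np.polyfit(x[3:], ..., 1)`, whose first coefficient, the slope, is shown in the legend). One way to read them, as a rough model: the slope $a$ is the marginal cost per image and the intercept $c$ a fixed per-call overhead, so the throughput $b / (a\,b + c)$ grows with the batch size towards $1/a$. A plain-Python sketch of the same least-squares fit, checked on synthetic timings (not measurements); the helper names are ours:

```python
def fit_affine(xs, ys):
    """Least-squares fit ys ~ slope * xs + intercept (what np.polyfit(xs, ys, 1) computes)."""
    n = len(xs)
    mean_x, mean_y = sum(xs) / n, sum(ys) / n
    sxx = sum((x - mean_x) ** 2 for x in xs)
    sxy = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    slope = sxy / sxx
    return slope, mean_y - slope * mean_x


def throughput(batch_size, slope, intercept):
    """Images per second predicted by the affine model t(b) = slope * b + intercept."""
    return batch_size / (slope * batch_size + intercept)


# synthetic timings that follow t(b) = 0.002 * b + 0.01 exactly
batch_sizes = [8, 16, 32, 64]
timings = [0.002 * b + 0.01 for b in batch_sizes]
slope, intercept = fit_affine(batch_sizes, timings)
for b in [1, 8, 64, 1024]:
    print(b, round(throughput(b, slope, intercept), 1))
```

A large intercept relative to the slope means that small batches waste most of the time in overhead, which is typically where batching helps the GPU most.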

We can do the same thing with the training time:

In [None]:
if torch.cuda.is_available():
    time_training_cpu_mean = []
    time_training_gpu_mean = []
    time_training_cpu_std = []
    time_training_gpu_std = []
    light_model_cpu = LightUNet(factor=8)
    light_model_cpu.train()
    light_model_gpu = LightUNet(factor=8).cuda()
    light_model_gpu.train()
    criterion = nn.MSELoss()
    optimizer_cpu = torch.optim.Adam(light_model_cpu.parameters(), lr=0.001)
    optimizer_gpu = torch.optim.Adam(light_model_gpu.parameters(), lr=0.001)
    for k in trange(0, 7):
        time_training_cpu = []
        time_training_gpu = []
        batch_size = 2**k
        x_cpu, y_cpu = trainset[0][:batch_size], trainset[1][:batch_size]
        x_gpu, y_gpu = x_cpu.cuda(), y_cpu.cuda()  # copies kept outside the timed region
        for _ in range(2):  # fewer repetitions on CPU (2) than on GPU (20)
            t0 = time.perf_counter()
            optimizer_cpu.zero_grad()
            outputs = light_model_cpu(x_cpu)
            loss = criterion(outputs, y_cpu)
            loss.backward()
            optimizer_cpu.step()
            time_training_cpu.append(time.perf_counter() - t0)
        time_training_cpu_mean.append(np.mean(time_training_cpu))
        time_training_cpu_std.append(np.std(time_training_cpu))
        for _ in range(20):
            torch.cuda.synchronize()
            t0 = time.perf_counter()
            optimizer_gpu.zero_grad()
            outputs = light_model_gpu(x_gpu)
            loss = criterion(outputs, y_gpu)
            loss.backward()
            optimizer_gpu.step()
            torch.cuda.synchronize()  # wait for the asynchronous GPU kernels to finish
            time_training_gpu.append(time.perf_counter() - t0)
        time_training_gpu_mean.append(np.mean(time_training_gpu))
        time_training_gpu_std.append(np.std(time_training_gpu))
    del light_model_cpu, light_model_gpu

    x = [2**i for i in range(7)]
    plt.figure(figsize=(15,5))
    plt.subplot(1,2,1)
    plt.plot(x, time_training_cpu_mean, label="CPU")
    plt.fill_between(x,np.array(time_training_cpu_mean)-np.array(time_training_cpu_std),
                        np.array(time_training_cpu_mean)+np.array(time_training_cpu_std),
                        alpha=0.3)
    affine_approx = np.polyfit(x[3:], time_training_cpu_mean[3:], 1)
    plt.plot(x, np.polyval(affine_approx, x), linestyle="--", color="black", label=f"Affine approx {affine_approx[0]:.3e}x")
    plt.xlabel("Batch size")
    plt.ylabel("Training time (s)")
    plt.legend()
    plt.subplot(1,2,2)
    plt.plot(x, time_training_gpu_mean, label="GPU")
    plt.fill_between(x,np.array(time_training_gpu_mean)-np.array(time_training_gpu_std),
                        np.array(time_training_gpu_mean)+np.array(time_training_gpu_std),
                        alpha=0.3)
    affine_approx = np.polyfit(x[3:], time_training_gpu_mean[3:], 1)
    plt.plot(x, np.polyval(affine_approx, x), linestyle="--", color="black", label=f"Affine approx {affine_approx[0]:.3e}x")
    plt.xlabel("Batch size")
    plt.ylabel("Training time (s)")
    plt.title("Training time for LightUNet")
    plt.legend()
    plt.show()


<font color='blue'>TODO:</font>
- Compare the training/inference time on CPU and GPU.
- Comment the results.

From now on, the code is device-agnostic: it automatically detects whether a GPU is available and uses it; otherwise, it runs on the CPU. To do so, we use a `torch.device` object.

```python
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
x = x.to(device)
```

Whenever we need to manipulate the results (e.g. with NumPy or Matplotlib), we use the following code:

```python
output = model(x)
output = output.detach().cpu().numpy()
```
The `detach()` method detaches the output from the computational graph (no gradient is tracked). The `cpu()` method moves the output from the GPU to the CPU, and the `numpy()` method converts it to a NumPy array (this conversion only works for CPU tensors that do not require gradients).


# 3. Training the U-Net

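We will now build a function to train the U-Net. The next cell defines `train`, which trains the model and evaluates it on the development set after each epoch. Note that the whole training set (69 images) is processed as a single batch, so each "epoch" corresponds to a single gradient step.
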
In [63]:
def train(model, trainset, devset=None, nepochs=10, lr0=1e-2, threshold=0, device="cpu"):
    """ Train a model
        Args:
        - model: the model to train
        - trainset: a tuple (X,Y) of the training data
        - devset: a tuple (X,Y) of the development data
        - nepochs: number of epochs
        - lr0: initial learning rate
        - threshold: threshold for the accuracy
        - device: the device to use
        Output:
        - the trained model
    """
    lossfn = nn.BCEWithLogitsLoss()
    optim  = torch.optim.Adam(model.parameters(), lr0)
    losses = torch.zeros(nepochs)
    dlosses = torch.zeros(nepochs)
    accs = torch.zeros(nepochs)
    xtrain, labels = trainset
    for e in range(nepochs):
        ### train part: one epoch = one gradient step on the full training set
        optim.zero_grad()
        preds = model(xtrain)
        l = lossfn(preds, labels)
        l.backward()
        optim.step()
        losses[e] = l.item()
        # pixel accuracy (%): a logit above `threshold` is classified as "nucleus"
        with torch.no_grad():
            accs[e] = ((preds > threshold).float() == labels).sum() * 100 / preds.numel()
        ### dev eval: eval mode for BatchNorm, no gradient tracking
        if devset is not None:
            dev, dlabels = devset
            model.eval()
            with torch.no_grad():
                dlosses[e] = lossfn(model(dev), dlabels).item()
            model.train()
        if e % 10 == 0:
            print(f"Epoch {e} | Loss {losses[e]:.3f} | Acc {accs[e]:.2f} | Dev Loss {dlosses[e]:.3f}")
    print(f"Epoch {e} | Loss {losses[e]:.3f} | Acc {accs[e]:.2f} | Dev Loss {dlosses[e]:.3f}")
    plt.figure(figsize=(10,5))
    plt.plot(losses, label="Train")
    if devset is not None:
        plt.plot(dlosses, label="Dev")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()
    plt.show()
    return model

    

<font color='blue'>TODO:</font>
- Run the following cell to train the U-Net for 10 epochs (on CPU).


In [None]:
m8 = LightUNet(8)
m8 = train(m8, trainset, testset, 10, lr0=2e-2, threshold=0)

<font color='blue'>TODO:</font>
- Modify the code to use the GPU.
- Train the U-Net for 100 epochs.
- Compare the training time between CPU and GPU.

In [67]:
def train(model, trainset, devset=None, nepochs=10, lr0=1e-2, threshold=0, device="cpu"):
    """ Train a model
        Args:
        - model: the model to train
        - trainset: a tuple (X,Y) of the training data
        - devset: a tuple (X,Y) of the development data
        - nepochs: number of epochs
        - lr0: initial learning rate
        - threshold: threshold for the accuracy
        - device: the device to use
        Output:
        - the trained model
    """
    ### Code to complete here

In [None]:
m8 = LightUNet(8).to(device)
m8 = train(m8, trainset, testset, 100, lr0=2e-2, threshold=0, device=device)

<font color='blue'>TODO:</font>
- Modify the code to return the model with the best validation loss.
- Ideally, mark the best model on the loss plot.

One way to do so is to save the model whenever the validation loss improves, and to load it back at the end of the training.
To do so, we can use the following code to save the model at a given epoch:

```python
state_dict = model.state_dict()
torch.save(state_dict, 'model.pth')
```

And the following code to load the model at the end of the training. 

```python
model.load_state_dict(torch.load('model.pth'))
```
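When the model is light (as in this case), you can simply keep the best weights in memory. Note that `state_dict()` returns references to the model's live parameter tensors: they must be copied, for instance with `copy.deepcopy`, otherwise the stored weights keep changing as the training goes on.

```python
import copy

best_weights = copy.deepcopy(model.state_dict())  # a snapshot, not a reference
# ... at the end of the training:
model.load_state_dict(best_weights)
```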
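The same idea can be packaged in a small helper. This is only a sketch, not the expected solution: `BestCheckpoint` is a hypothetical name, and it relies on `copy.deepcopy`, which works on a `state_dict()` (a dict of tensors) as well as on plain Python objects. It is shown here on a plain dict of lists, which behaves like a `state_dict` with respect to aliasing.

```python
import copy


class BestCheckpoint:
    """Keep an in-memory snapshot of the weights with the lowest validation loss."""

    def __init__(self):
        self.best_loss = float("inf")
        self.best_epoch = None
        self.best_weights = None

    def update(self, epoch, loss, weights):
        """Snapshot `weights` (e.g. model.state_dict()) if `loss` is the best so far."""
        if loss < self.best_loss:
            self.best_loss = loss
            self.best_epoch = epoch
            # deepcopy: a plain reference would keep following later in-place updates
            self.best_weights = copy.deepcopy(weights)
            return True
        return False


# toy run: `weights` is mutated in place, like model parameters during training
weights = {"w": [0.0]}
ckpt = BestCheckpoint()
for epoch, dev_loss in enumerate([0.9, 0.5, 0.7]):
    weights["w"][0] = float(epoch)  # "training" changes the weights in place
    ckpt.update(epoch, dev_loss, weights)
print(ckpt.best_epoch, ckpt.best_weights)  # 1 {'w': [1.0]}
```

Inside `train`, one could call `ckpt.update(e, dlosses[e].item(), model.state_dict())` after the dev evaluation, then `model.load_state_dict(ckpt.best_weights)` before returning, and mark the best epoch on the loss plot with `plt.axvline(ckpt.best_epoch, linestyle=":")`.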

In [88]:
def train(model, trainset, devset=None, nepochs=10, lr0=1e-2, threshold=0, device="cpu"):
    """ Train a model
        Args:
        - model: the model to train
        - trainset: a tuple (X,Y) of the training data
        - devset: a tuple (X,Y) of the development data
        - nepochs: number of epochs
        - lr0: initial learning rate
        - threshold: threshold for the accuracy
        - device: the device to use
        Output:
        - the best trained model
    """
    ### Code to complete here
        

In [None]:
m8 = LightUNet(8).to(device)
m8 = train(m8, trainset, testset, 100, lr0=2e-2, threshold=0, device=device)

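The following cell contains the code to assess the performance of a model on the test set. It defines:
- `plotres`, a function to plot the predictions next to the ground truth;
- `evalseg`, a function to compute the pixel accuracy, precision and recall for a given threshold;
- `plotroc`, a function to plot the precision-recall (PR) curve.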

The PR curve is a good way to assess the performance of a binary classifier. It plots the precision against the recall for different decision thresholds.

The precision is the number of true positives divided by the number of true positives plus false positives, $TP/(TP+FP)$. Intuitively, it measures the ability of the classifier not to label a negative sample as positive.

The recall is the number of true positives divided by the number of true positives plus false negatives, $TP/(TP+FN)$. Intuitively, it measures the ability of the classifier to find all the positive samples.
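These definitions can be checked on a handful of pixels. The sketch below, in plain Python with toy logits and labels, counts TP, FP and FN for the rule `logit > thr`, as `evalseg` does on tensors; `precision_recall` is a hypothetical helper name.

```python
def precision_recall(logits, labels, thr=0.0):
    """Precision and recall (in %) of the rule `logit > thr`, for flat lists of pixels.

    labels: 1 for a nucleus pixel, 0 otherwise. Returns NaN when a ratio is undefined.
    """
    tp = fp = fn = 0
    for z, y in zip(logits, labels):
        predicted = z > thr
        if predicted and y == 1:
            tp += 1  # true positive
        elif predicted:
            fp += 1  # false positive
        elif y == 1:
            fn += 1  # false negative
    precision = 100 * tp / (tp + fp) if tp + fp else float("nan")
    recall = 100 * tp / (tp + fn) if tp + fn else float("nan")
    return precision, recall


# toy pixels: raising the threshold trades recall for precision
logits = [2.0, 0.5, -1.0, 3.0, -0.2]
labels = [1, 0, 1, 1, 0]
for thr in [-2.0, 0.0, 1.0]:
    print(thr, precision_recall(logits, labels, thr))
```

Sweeping the threshold, as `plotroc` does, traces one (precision, recall) point per threshold.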

In [98]:

def plotres(m, testset, device):
    """ Plot the results of a segmentation model
        Args:
        - m: the model
        - testset: the test set
        - device: the device on which the model is stored
    """
    m.eval()
    x, y = testset
    x, y = x.to(device), y.to(device)  # works for "cpu", "cuda" or a torch.device
    n = x.shape[0]
    plt.figure(figsize=(5, n))
    with torch.no_grad():
        for i in range(n):
            out = m(x[i].unsqueeze(0)).squeeze()  # one forward pass per image
            plt.subplot(n, 4, i*4+1)
            if i == 0:
                plt.title("Input")
            plt.imshow(x[i].cpu().squeeze())
            plt.axis('off')
            plt.subplot(n, 4, i*4+2)
            if i == 0:
                plt.title("Output")
            plt.imshow(out.cpu())
            plt.axis('off')
            plt.subplot(n, 4, i*4+3)
            if i == 0:
                plt.title("Output>0")
            plt.imshow((out > 0).cpu())
            plt.axis('off')
            plt.subplot(n, 4, i*4+4)
            if i == 0:
                plt.title("Ground Truth")
            plt.imshow(y[i].cpu().squeeze() == 1)
            plt.axis('off')
    plt.show()
    m.train()


def evalseg(preds, labels, thr=0):
    """ Evaluate a segmentation model
        Args:
        - preds: the predictions (logits)
        - labels: the ground truth (1 = nucleus)
        - thr: the threshold applied to the logits
        Output:
        - accuracy, precision, recall (in %)
    """
    Tot = preds.numel()
    out = preds.squeeze().detach()
    selx = out > thr                  # predicted positives
    sely = (labels == 1).squeeze()    # actual positives
    x1 = (selx*1).sum()               # TP + FP
    y1 = (sely*1).sum()               # TP + FN
    agree = (selx == sely)            # correctly classified pixels (TP + TN)
    acc = agree.sum()*100/Tot
    prec = selx[agree].sum()*100/x1   # TP / (TP + FP), NaN if nothing is predicted positive
    rapp = sely[agree].sum()*100/y1   # TP / (TP + FN)
    return acc.cpu().item(), prec.cpu().item(), rapp.cpu().item()
    
def plotroc(m, testset, start=-2, end=2, nsteps=50, device="cpu"):
    """ Plot the precision-recall (PR) curve of one or several segmentation models
        Args:
        - m: the model or a list of models
        - testset: the test set
        - start: the smallest threshold (on the logits)
        - end: the largest threshold
        - nsteps: the number of thresholds
        - device: the device on which the models are stored
    """
    plt.figure(figsize=(5,5))
    lm = m if isinstance(m, list) else [m]
    x, y = testset
    x, y = x.to(device), y.to(device)
    for idx, model in enumerate(lm):
        model.eval()  # use the BatchNorm running statistics, not those of the test batch
        with torch.no_grad():
            preds = model(x)
        model.train()
        pr = []
        for thr in np.linspace(start, end, nsteps):
            acc, prec, rap = evalseg(preds, y, thr)
            pr.append([prec, rap])
        pr = np.array(pr)
        plt.scatter(pr[:,0], pr[:,1], label=str(idx), alpha=0.5)
    plt.xlabel("Precision")
    plt.ylabel("Recall")
    plt.legend()
    plt.xlim(0,100)
    plt.ylim(0,100)
    plt.show()

<font color='blue'>TODO:</font>
- Run the following cell to assess the performance of the model on the test set.

In [None]:
plotres(m8, testset, device)

In [None]:
plotroc(m8, testset, -5, 5, 50, device)

<font color='blue'>TODO:</font>
- Train LightUNets with different numbers of features (4 and 16).
- Tune the learning rate of each model by observing the training loss.
- Compare the performance of the different models on the test set.
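Comparing PR curves by eye can be hard when they cross. One common single-number summary (an addition, not used in the original notebook) is the F1 score, the harmonic mean of precision and recall, maximized over the thresholds. A plain-Python sketch, assuming you have collected one `(precision, recall)` pair per threshold; the helper names are ours:

```python
import math


def f1_score(precision, recall):
    """Harmonic mean of precision and recall (both in %)."""
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


def best_f1(thresholds, pr_pairs):
    """Return (best F1, its threshold); pairs containing a NaN are skipped."""
    best_score, best_thr = float("-inf"), None
    for thr, (p, r) in zip(thresholds, pr_pairs):
        if math.isnan(p) or math.isnan(r):
            continue  # precision is undefined when no pixel is predicted positive
        score = f1_score(p, r)
        if score > best_score:
            best_score, best_thr = score, thr
    return best_score, best_thr


# toy (precision, recall) values, one pair per threshold
thresholds = [-2.0, 0.0, 1.0, 4.0]
pr_pairs = [(60.0, 100.0), (200 / 3, 200 / 3), (100.0, 200 / 3), (float("nan"), 0.0)]
score, thr = best_f1(thresholds, pr_pairs)
print(round(score, 1), thr)  # 80.0 1.0
```

With the notebook's functions, the pairs could be obtained as `[evalseg(preds, y, thr)[1:] for thr in thresholds]`, with `preds` computed in eval mode and under `torch.no_grad()`.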

In [None]:
m4 = LightUNet(4).to(device)
m4 = train(m4,trainset,testset,100, lr0= ... , device=device)
plotres(m4, testset, device)


In [None]:
m16 = LightUNet(16).to(device)
m16 = train(m16,trainset,testset,100, lr0= ... , device=device)
plotres(m16, testset, device)

In [None]:
plotroc([m4,m8,m16], testset, -5, 5, 50, device)


<font color='blue'>TODO:</font>
- Train the full U-Net with 8, 16 and 32 features for 200 epochs.

In [None]:
bm8 = UNet(8).to(device)
bm8 = train(bm8,trainset,testset,200, lr0= ... , device=device)

In [None]:
plotroc([m8,bm8], testset, -5, 5, 50, device)

In [None]:
bm16 = UNet(16).to(device)
bm16 = train(bm16,trainset,testset,200, lr0= ... , device=device)

In [None]:
plotroc([m16,bm16], testset, -5, 5, 50, device)

In [None]:
bm32 = UNet(32).to(device)
bm32 = train(bm32,trainset,testset,200, lr0= ... , device=device)

In [None]:
plotroc([bm8,bm16,bm32], testset, -5, 5, 50, device)