# Exercise 9: Semantic Segmentation with Deep Learning

In the previous exercise, we implemented a full routine for training and validating DL models for image classification.
Now, we shall take it to the next level and perform pixel-wise classification, also known as semantic segmentation.
This is of particular interest to remote sensing, as it allows us to e.g. obtain spatially well-resolved land cover maps, among other products.

The basic ingredients are exactly the same as for image classification (optimiser, loss function, training loop, etc.). We do have some changes to make, though:
* Dataset: this time, we don't just want a single output number (class), but one value per spatial location. Essentially, our dataset should also provide a second image where each pixel has the class index as its value.
* Model: likewise, we need a suitable model that provides spatial outputs rather than a single class vector. You have seen some examples in the lecture. In this exercise we shall use a flavour of the [Hypercolumn](https://www.cv-foundation.org/openaccess/content_cvpr_2015/papers/Hariharan_Hypercolumns_for_Object_2015_CVPR_paper.pdf) to do this job.
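
To make the difference in tensor shapes concrete, here is a minimal sketch with random tensors. The batch size and the 64x64 patch size are assumptions chosen for illustration only; they are not the values used later in this exercise.

```python
import torch

num_classes = 6                      # six land cover classes
batch, height, width = 2, 64, 64     # illustrative sizes (assumed)

# image classification: one score vector and one class index per image
cls_logits = torch.randn(batch, num_classes)
cls_target = torch.randint(0, num_classes, (batch,))

# semantic segmentation: one score vector and one class index per pixel
seg_logits = torch.randn(batch, num_classes, height, width)
seg_target = torch.randint(0, num_classes, (batch, height, width))

# the predicted label map picks the highest-scoring class at every pixel
seg_pred = seg_logits.argmax(dim=1)

print(cls_logits.shape, cls_target.shape)
print(seg_logits.shape, seg_target.shape, seg_pred.shape)
```

Note that the segmentation target has no class dimension: it stores one class index per pixel, just like the label image described above.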

## 1. Setup

### 1.1 Install dependencies

Please re-run this as we now need another package (for downloading the data).

In [None]:
import sys
!{sys.executable} -m pip install torch torchvision
!{sys.executable} -m pip install matplotlib
!{sys.executable} -m pip install tqdm                      # this gives us a pretty progress bar

# for downloading files from Google drive
!{sys.executable} -m pip install gdown

### 1.2 Check if GPU available

This is even more important for semantic segmentation, as our models and data tensors are going to be significantly larger.

Run the following code block and proceed if the response is `True`. Otherwise, ask an instructor.

In [None]:
import torch

print(torch.cuda.is_available())

### 1.3 Random seed

In [None]:
seed = 323444           # the seed value used to initialise the random number generator of PyTorch
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

## 2. Dataset

For this exercise we shall be using the [ISPRS Vaihingen](https://www2.isprs.org/commissions/comm2/wg4/benchmark/2d-sem-label-vaihingen/) semantic segmentation dataset.
This is a set of fully-labelled satellite image-segmentation mask pairs with 9 cm resolution and six land cover classes: Impervious, Buildings, Low Vegetation, Tree, Car, Clutter. The images come from a large satellite scene over the town of Vaihingen in Germany, which was divided into 33 patches, some of which are available with ground truth.

These patches are still too large for our model: we would quickly run out of GPU memory if we tried to process an image of e.g. 4000x3000 pixels. Hence, they need to be divided further into even smaller patches. This has already been done for you; all you need to do is download the image-label patch pairs (sized 512x512 pixels) by running the code cell below.
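
For illustration, cutting a large tile into 512x512 patches could look roughly like the sketch below. This is not necessarily how the provided patches were produced: the function name `tile_image` is hypothetical, and dropping incomplete border patches is an assumption of this sketch.

```python
import numpy as np


def tile_image(image, patch_size=512):
    """Split an (H, W, C) array into non-overlapping square patches.

    Border regions that do not fill a complete patch are dropped
    (an assumption made for this sketch).
    """
    height, width = image.shape[:2]
    patches = []
    for top in range(0, height - patch_size + 1, patch_size):
        for left in range(0, width - patch_size + 1, patch_size):
            patches.append(image[top:top + patch_size, left:left + patch_size])
    return patches


# hypothetical 3000x4000 tile with three bands
scene = np.zeros((3000, 4000, 3), dtype=np.uint8)
patches = tile_image(scene)
print(len(patches), patches[0].shape)
```

The same grid would have to be applied to the DSM and the label image so that all three products stay aligned.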

In [None]:
!gdown --id 16OkRr9Ck-XGKy3LjNqGxCTx0pmU1aoUE

!unzip /content/Vaihingen_512_512.zip
data_root = 'Vaihingen_512_512'

Like before, we need to wrap these images in a Dataset class.
It's quite involved this time, however, since we don't just have simple RGB images, but need to collect multiple satellite products. The code below does all of this for you.

In [None]:
import os
import torch
from torch.utils.data import dataset
from torch.utils.data import DataLoader
import torchvision.transforms as T      # transformations that can be used e.g. for data conversion or augmentation
import numpy as np
from PIL import Image


class VaihingenDataset(dataset.Dataset):
    '''
        Custom Dataset class that loads images and ground truth segmentation
        masks from a directory.
    '''

    # image statistics, calculated in advance as averages across the full
    # training data set
    IMAGE_MEANS = (
        (121.03431026287558, 82.52572736507886, 81.92368178210943),     # IR-R-G tiles
        (285.34753853934154)                                            # DSM
    )
    IMAGE_STDS = (
        (54.21029197978022, 38.434924159900554, 37.040640374137475),    # IR-R-G tiles
        (6.485453035150256)                                             # DSM
    )


    # label class names
    LABEL_CLASSES = (
        'Buildings', 'Tree', 'Low Vegetation', 'Clutter', 'Car', 'Impervious'
    )


    def __init__(self, data_root):
        '''
            Dataset class constructor. Here we initialize the dataset instance
            and retrieve file names (and other metadata, if present) for all the
            images and labels (ground truth semantic segmentation maps).
        '''
        super().__init__()

        self.data_root = data_root
        # List all the files in image folder and make a list of samples
        imgs_folder = os.path.join(self.data_root,'imgs')
        samples_list = os.listdir(imgs_folder)
        self.data = []
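        for sample in samples_list:
            if 'tif' in sample:
                # keep the last two underscore-separated parts of the file name;
                # image, DSM and label files share this suffix
                sample_name = '_'.join(sample.rsplit('_',2)[1:])
                self.data.append(sample_name)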


    def __len__(self):
        '''
            This function tells the Data Loader how many images there are in
            this dataset.
        '''
        return len(self.data)

    
    def __getitem__(self, idx):
        '''
            Here's where we load, prepare, and convert the images and
            segmentation mask for the data element at the given "idx".
        '''
        item = self.data[idx]

        # load image
        img_filename = 'top_mosaic_09cm_'+item
        image = Image.open(os.path.join(self.data_root,'imgs',img_filename))
        # load dsm 
        dsm_filename = 'dsm_09cm_matching_'+item
        dsm = Image.open(os.path.join(self.data_root,'dsm',dsm_filename))
        # load segmentation mask (groundtruth)
        labels = Image.open(os.path.join(self.data_root,'gts',img_filename))
        labels = np.array(labels, dtype=np.int64)   # convert to NumPy array temporarily

        # NOTE: at this point it would make sense to perform data augmentation.
        # However, the default augmentations built-in to PyTorch (resp.
        # Torchvision) (i.) only support RGB images; (ii.) only work on the
        # images themselves. In our case, however, we have multispectral data
        # and need to also transform the segmentation mask.
        # This is not difficult to do, but goes beyond the scope of this exercise.
        # For the sake of brevity, we'll leave it out accordingly.
        # What we will have to do, however, is to normalize the image data.
        image = np.array(image, dtype=np.float32)
        image = (image - self.IMAGE_MEANS[0]) / self.IMAGE_STDS[0]
        dsm = np.array(dsm, dtype=np.float32)
        dsm = (dsm - self.IMAGE_MEANS[1]) / self.IMAGE_STDS[1]

        # finally, we need to convert our data into the torch.Tensor format. For
        # the images, we already have a "ToTensor" transform available, but we
        # need to concatenate the images together.
        image = T.ToTensor()(image)
        dsm = T.ToTensor()(dsm)
        #tensors = [T.ToTensor()(i) for i in images]
        inputs = torch.cat([image,dsm], dim=0).float()         # concatenate along spectral dimension and make sure it's in 32-bit floating point

        # For the labels, we need to convert the PIL image to a torch.Tensor.
        labels = torch.from_numpy(labels).long()            # labels need to be in 64-bit integer format

        return inputs, labels





# we also create a function for the data loader here (see Section 2.6 in Exercise 6)
def load_dataloader(batch_size, split='train'):
  return DataLoader(
      VaihingenDataset(os.path.join(data_root, split)),
      batch_size=batch_size,
      shuffle=(split=='train'),       # we shuffle the image order for the training dataset
      num_workers=2                   # perform data loading with two worker processes
  )
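
The NOTE inside `__getitem__` mentions that data augmentation would have to transform the multi-band input and the segmentation mask together. A minimal sketch of such a joint augmentation, using random flips only, is given below. The function name `random_flip` and the flip probability are assumptions of this sketch and not part of the exercise code.

```python
import torch


def random_flip(inputs, labels, p=0.5):
    """Randomly flip a (C, H, W) input tensor and its (H, W) mask together."""
    if torch.rand(1).item() < p:
        # horizontal flip: reverse the width axis of both tensors
        inputs = torch.flip(inputs, dims=[2])
        labels = torch.flip(labels, dims=[1])
    if torch.rand(1).item() < p:
        # vertical flip: reverse the height axis of both tensors
        inputs = torch.flip(inputs, dims=[1])
        labels = torch.flip(labels, dims=[0])
    return inputs, labels


# dummy sample: 4 bands, 8x8 pixels, 6 classes
inputs = torch.randn(4, 8, 8)
labels = torch.randint(0, 6, (8, 8))
aug_inputs, aug_labels = random_flip(inputs, labels)
print(aug_inputs.shape, aug_labels.shape)
```

Such a function could be called on `inputs` and `labels` just before the `return` statement of `__getitem__`, ideally for the training split only.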

Let's visualise some images.

In [None]:
import os
%matplotlib inline
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

# discrete colour scheme, in the order of VaihingenDataset.LABEL_CLASSES:
# 'Buildings', 'Tree', 'Low Vegetation', 'Clutter', 'Car', 'Impervious'
cMap = ListedColormap(['grey', 'darkgreen', 'lawngreen', 'red', 'orange', 'black'])

dataset_train = VaihingenDataset(os.path.join(data_root, 'train'))

# draw a random sample
idx = torch.randint(0, len(dataset_train), (1,)).item()    # convert the one-element tensor to a plain Python int
data, target = dataset_train[idx]
print(f'Image tensor size: {data.size()}')
print(f'Label tensor size: {target.size()}')

# visualise
plt.figure()
plt.imshow(data[:3,...].permute(1,2,0).numpy())     # first three bands: NIR-R-G
plt.title('Input: NIR-R-G satellite imagery')
plt.show()
plt.figure()
plt.imshow(data[3,...].squeeze().numpy())           # band 4: DSM
plt.title('Input: DSM')
plt.show()
fig = plt.figure()
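# fix the colour range so that every class keeps its colour, even if a patch does not contain all classes
cax = plt.imshow(target.squeeze().numpy(), cmap=cMap, vmin=-0.5, vmax=len(dataset_train.LABEL_CLASSES)-0.5)   # target: segmentation mask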
cbar = fig.colorbar(cax, ticks=list(range(len(dataset_train.LABEL_CLASSES))))
cbar.ax.set_yticklabels(list(dataset_train.LABEL_CLASSES))
plt.title('Target: segmentation mask')
plt.show()
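
Because the dataset returns standardised values (zero mean, unit variance per band), Matplotlib clips everything outside [0, 1] when showing the three image bands as a colour image. One way to get a more faithful preview is to rescale each band to [0, 1] first. The helper below is a sketch only; the name `to_displayable` is hypothetical, and per-band min-max scaling is just one of several reasonable choices.

```python
import torch


def to_displayable(bands):
    """Rescale each band of a (C, H, W) tensor to [0, 1]; return (H, W, C)."""
    bands = bands.clone().float()
    mins = bands.amin(dim=(1, 2), keepdim=True)
    maxs = bands.amax(dim=(1, 2), keepdim=True)
    # guard against division by zero for constant bands
    scaled = (bands - mins) / (maxs - mins).clamp(min=1e-8)
    return scaled.permute(1, 2, 0)


# dummy standardised image with three bands
dummy = torch.randn(3, 8, 8) * 3
preview = to_displayable(dummy)
print(preview.shape, preview.min().item(), preview.max().item())
```

With the sample drawn above, this could be used as `plt.imshow(to_displayable(data[:3,...]).numpy())`.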


## 3. Model

Now let's define our semantic segmentation model! The model must produce an output with exactly the same spatial size (height and width) as its input. As stated above, we shall use a [Hypercolumn](https://www.cv-foundation.org/openaccess/content_cvpr_2015/papers/Hariharan_Hypercolumns_for_Object_2015_CVPR_paper.pdf) for this task. This is a _Fully Convolutional Network_ (FCN), which means that it contains no fully-connected layers at all and therefore always preserves some notion of space (it is built only from convolutions, pooling, ReLU, etc.). The nice property of FCNs is that they accept inputs of any spatial size above a certain minimum, limited only by the available GPU memory.

A Hypercolumn basically performs downsampling via convolutions, poolings, etc., like you have been doing in Exercise 6 for image classification. However, unlike a classifier, it keeps every intermediate output, upsamples (interpolates) them to the original image's size, stacks them together to a large tensor (a hypercolumn) and uses this to perform pixel-wise classification:

![Hypercolumn](https://www.researchgate.net/profile/Devis-Tuia/publication/323273293/figure/fig3/AS:614258178027521@1523461970498/Hypercolumn-based-architecture-used-in-all-our-experiments-Note-that-all-the-layers-are.png)
Image source: Marcos, D., Volpi, M., Kellenberger, B. and Tuia, D., 2018. Land cover mapping at very high resolution with rotation equivariant CNNs: Towards small yet accurate models. ISPRS journal of photogrammetry and remote sensing, 145, pp.96-107.
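
The core idea can be sketched with dummy tensors before building the real model. The feature maps below are random stand-ins rather than outputs of actual layers, and all channel counts and sizes are assumptions chosen for illustration.

```python
import torch
import torch.nn as nn

x = torch.randn(1, 4, 64, 64)        # dummy input: 4 bands, 64x64 pixels
feat_a = torch.randn(1, 8, 16, 16)   # stand-in for an intermediate feature map
feat_b = torch.randn(1, 16, 4, 4)    # stand-in for a coarser feature map

# interpolate every feature map back to the input's spatial size
upsample = nn.Upsample(size=tuple(x.shape[2:]), mode='bilinear', align_corners=False)

# stack everything along the channel dimension: 4 + 8 + 16 = 28 channels
hypercolumn = torch.cat((x, upsample(feat_a), upsample(feat_b)), dim=1)

# a 1x1 convolution then classifies each pixel from its stacked feature vector
classifier = nn.Conv2d(hypercolumn.size(1), 6, kernel_size=1)
pred = classifier(hypercolumn)

print(hypercolumn.shape, pred.shape)
```

Since a 1x1 convolution looks at one spatial location at a time, it acts as a small per-pixel classifier on the stacked features.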


Let's implement a Hypercolumn with the following architecture:
1. BLOCK 1:
    1. 2D convolution, 32 kernels of size 5x5, stride 4, zero-padding 0
    2. 2D max pool, kernel size 2x2, stride 1
    3. Batch Normalisation
    4. ReLU
2. BLOCK 2:
    1. 2D convolution, 64 kernels of size 5x5, stride 4, zero-padding 0
    2. 2D max pool, kernel size 2x2, stride 1
    3. Batch Normalisation
    4. ReLU
3. BLOCK 3:
    1. 2D convolution, 128 kernels of size 5x5, stride 2, zero-padding 0
    2. 2D max pool, kernel size 2x2, stride 1
    3. Batch Normalisation
    4. ReLU
4. BLOCK 4:
    1. 2D convolution, 256 kernels of size 3x3, stride 1, zero-padding 0
    2. 2D max pool, kernel size 2x2, stride 1
    3. Batch Normalisation
    4. ReLU
5. HYPERCOLUMN: here you do the following:
    1. Take the input itself and the outputs of BLOCKs 1, 2, 3 and 4 (after the ReLU)
    2. Interpolate them to the original input's spatial size (tip: use an instance of [torch.nn.Upsample](https://pytorch.org/docs/stable/generated/torch.nn.Upsample.html))
    3. Concatenate them together (tip: `torch.cat((tensor1, tensor2, ...), dim=1)`)
6. FINAL BLOCK: this works on the output of 5. HYPERCOLUMN:
    1. 2D convolution, 256 kernels of size 1x1, stride 1, zero-padding 0
    2. Batch Normalisation
    3. ReLU
    4. 2D convolution, 6 kernels of size 1x1, stride 1, zero-padding 0 (output of model)
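
As a sanity check for the architecture above, the spatial size after each block and the number of channels in the hypercolumn can be worked out by hand. The sketch below uses the standard output-size formula for convolution and pooling layers (no dilation, rounding down) and assumes 512x512 input patches with four bands (NIR, R, G, DSM), as provided by the dataset class. The helper name `conv_out_size` is hypothetical.

```python
def conv_out_size(size, kernel, stride=1, padding=0):
    """Output size of a convolution or pooling layer along one spatial axis."""
    return (size + 2 * padding - kernel) // stride + 1


size = 512        # input patches are 512x512 pixels
channels = [4]    # the input itself contributes its four bands

# (number of kernels, kernel size, stride) of the convolution in each block
blocks = [(32, 5, 4), (64, 5, 4), (128, 5, 2), (256, 3, 1)]

sizes = []
for out_channels, kernel, stride in blocks:
    size = conv_out_size(size, kernel, stride)   # 2D convolution
    size = conv_out_size(size, 2, 1)             # 2x2 max pool, stride 1
    sizes.append(size)
    channels.append(out_channels)

print(sizes)           # [126, 30, 12, 9]
print(sum(channels))   # 484
```

Batch Normalisation and ReLU do not change the tensor shape. The first 1x1 convolution of the FINAL BLOCK therefore receives 484 input channels under these assumptions.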

In [None]:
import torch.nn as nn


class Hypercolumn(nn.Module):

    def __init__(self):
        super(Hypercolumn, self).__init__()

        #TODO: define your architecture and forward pass here
        # ...
    

    def forward(self, x):
        #TODO
        # ...

Let's test it!

In [None]:
dataloader_train = load_dataloader(batch_size=2, split='train')
model = Hypercolumn()

data, _ = next(iter(dataloader_train))
pred = model(data)

assert pred.size(1) == len(dataset_train.LABEL_CLASSES), f'ERROR: invalid number of model output channels (should be # classes {len(dataset_train.LABEL_CLASSES)}, got {pred.size(1)})'
assert pred.size(2) == data.size(2), f'ERROR: invalid spatial height of model output (should be {data.size(2)}, got {pred.size(2)})'
assert pred.size(3) == data.size(3), f'ERROR: invalid spatial width of model output (should be {data.size(3)}, got {pred.size(3)})'

## 4. Model training

Everything else follows exactly the same principle as for image classification!
Hence, you can simply copy-paste all your code cells from Section 4 ("Implement training routine") of the previous exercise.
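
The reason the code carries over is that PyTorch's cross-entropy loss also accepts spatial inputs. The sketch below assumes that the previous exercise used `nn.CrossEntropyLoss` and computed the overall accuracy (OA) from the arg max of the predictions; if your code differs, treat this only as an illustration with random tensors.

```python
import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()

# dummy model output: (batch, classes, height, width)
pred = torch.randn(2, 6, 32, 32)
# dummy ground truth: (batch, height, width), one class index per pixel
labels = torch.randint(0, 6, (2, 32, 32))

# the loss is averaged over all pixels of all images in the batch
loss = criterion(pred, labels)

# overall accuracy: fraction of correctly classified pixels
yhat = pred.argmax(dim=1)
oa = (yhat == labels).float().mean()

print(loss.item(), oa.item())
```

If your accuracy calculation already uses a mean over all elements of `yhat == labels`, it works unchanged for per-pixel predictions.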

In [None]:
#TODO: copy-paste criterion block here
criterion = ...

In [None]:
#TODO: copy-paste optimiser block here

In [None]:
#TODO: copy-paste training block here

In [None]:
#TODO: copy-paste validation block here

The rest (Section 5 and later from Exercise 7) is also the same, but we change the model name and parameters a bit here, which is why these code blocks are given to you below.

In [None]:
import glob

os.makedirs('cnn_states/Hypercolumn', exist_ok=True)

def load_model(epoch='latest'):
  model = Hypercolumn()
  modelStates = glob.glob('cnn_states/Hypercolumn/*.pth')
  if len(modelStates) and (epoch == 'latest' or epoch > 0):
    modelStates = [int(m.replace('cnn_states/Hypercolumn/','').replace('.pth', '')) for m in modelStates]
    if epoch == 'latest':
      epoch = max(modelStates)
    stateDict = torch.load(open(f'cnn_states/Hypercolumn/{epoch}.pth', 'rb'), map_location='cpu')
    model.load_state_dict(stateDict)
  else:
    # fresh model
    epoch = 0
  return model, epoch


def save_model(model, epoch):
  torch.save(model.state_dict(), open(f'cnn_states/Hypercolumn/{epoch}.pth', 'wb'))

In [None]:
# define hyperparameters
device = 'cuda'
start_epoch = 0        # set to 0 to start from scratch again or to 'latest' to continue training from saved checkpoint
batch_size = 2
learning_rate = 0.1
weight_decay = 0.001
num_epochs = 10



# initialise data loaders
dl_train = load_dataloader(batch_size, 'train')
dl_val = load_dataloader(batch_size, 'val')

# load model
model, epoch = load_model(epoch=start_epoch)
optim = setup_optimiser(model, learning_rate, weight_decay)

# do epochs
while epoch < num_epochs:

  # training
  model, loss_train, oa_train = train_epoch(dl_train, model, optim, device)

  # validation
  loss_val, oa_val = validate_epoch(dl_val, model, device)

  # print stats
  print('[Ep. {}/{}] Loss train: {:.2f}, val: {:.2f}; OA train: {:.2f}, val: {:.2f}'.format(
      epoch+1, num_epochs,
      loss_train, loss_val,
      100*oa_train, 100*oa_val
  ))

  # save model
  epoch += 1
  save_model(model, epoch)

## 5. Model validation

Like in Exercise 7, we could do a final accuracy evaluation now. Unfortunately, we don't have access to the labels of the Vaihingen test images: they are hidden on an official Web evaluation server (the Vaihingen dataset was at some point a contest where people could submit their scores and compete against each other!).

But, we can do something else that we could not do in Exercise 7: visualise our results! Our model provides segmentation masks after all… So let's do this!

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt

def visualize(dataLoader, epochs, numImages=5):
  models = [load_model(e)[0] for e in epochs]
  numModels = len(models)
  for idx, (data, labels) in enumerate(dataLoader):
    if idx == numImages:
      break

    _, ax = plt.subplots(nrows=1, ncols=numModels+1, figsize = (20, 15))

    # plot ground truth
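    ax[0].imshow(labels[0,...].cpu().numpy(), vmin=0, vmax=len(VaihingenDataset.LABEL_CLASSES)-1)   # fixed colour range: same colour per class in every panel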
    ax[0].axis('off')
    if idx == 0:
      ax[0].set_title('Ground Truth')

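    for mIdx, model in enumerate(models):
      model = model.to(device)
      model.eval()      # inference mode: Batch Normalisation uses its running statistics
      with torch.no_grad():
        pred = model(data.to(device))

        # get the label (i.e., the maximum position for each pixel along the class dimension)
        yhat = torch.argmax(pred, dim=1)

        # plot model predictions (fixed colour range: same colour per class in every panel)
        ax[mIdx+1].imshow(yhat[0,...].cpu().numpy(), vmin=0, vmax=len(VaihingenDataset.LABEL_CLASSES)-1)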
        ax[mIdx+1].axis('off')
        if idx == 0:
          ax[mIdx+1].set_title(f'Epoch {epochs[mIdx]}')


# visualize predictions for a number of epochs
dl_val_single = load_dataloader(1, 'val')

# load model states at different epochs
epochs = [0, 1, 5, 'latest']                                          #TODO: adapt this list to the epochs for which you have saved model states

visualize(dl_val_single, epochs, numImages=5)