Training for image segmentation

Image segmentation is the task of predicting a class for every pixel in an image. This makes it possible to delimit objects and shapes of many classes within an image very precisely, all at once.

Data format

  • Image segmentation image & mask example

  • Image segmentation main image list format from /opt/platform/examples/cityscapes/train.txt:
/opt/platform/examples/cityscapes/train/imgs//konstanz_000000_001391_leftImg8bit.png /opt/platform/examples/cityscapes/train/annot/konstanz_000000_001391_gtCoarse_labelTrainIds.png
/opt/platform/examples/cityscapes/train/imgs//dortmund_000000_000053_leftImg8bit.png /opt/platform/examples/cityscapes/train/annot/dortmund_000000_000053_gtCoarse_labelTrainIds.png
/opt/platform/examples/cityscapes/train/imgs//heidelberg_000000_001065_leftImg8bit.png /opt/platform/examples/cityscapes/train/annot/heidelberg_000000_001065_gtCoarse_labelTrainIds.png
/opt/platform/examples/cityscapes/train/imgs//berlin_000170_000019_leftImg8bit.png /opt/platform/examples/cityscapes/train/annot/berlin_000170_000019_gtFine_labelTrainIds.png
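The list format above can be read back with a few lines of Python, which is handy for inspecting or filtering a dataset. This is a minimal sketch assuming paths contain no spaces, since a single space is the separator; `parse_segmentation_list` is a hypothetical helper, not part of the DD platform.

```python
def parse_segmentation_list(text):
    """Parse a segmentation list file into (image_path, mask_path) pairs.

    Each non-empty line must hold exactly two space-separated paths:
    the image first, then its mask. Paths are assumed to contain no spaces.
    """
    pairs = []
    for lineno, line in enumerate(text.splitlines(), start=1):
        line = line.strip()
        if not line:
            continue  # ignore blank lines
        fields = line.split(" ")
        if len(fields) != 2:
            raise ValueError(f"line {lineno}: expected 'image mask', got {line!r}")
        pairs.append((fields[0], fields[1]))
    return pairs


example = (
    "/opt/platform/examples/cityscapes/train/imgs//konstanz_000000_001391_leftImg8bit.png "
    "/opt/platform/examples/cityscapes/train/annot/konstanz_000000_001391_gtCoarse_labelTrainIds.png\n"
)
for image_path, mask_path in parse_segmentation_list(example):
    print(image_path, "->", mask_path)
```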

We suggest organizing the files as follows:

your_data/train
your_data/train/train.txt
your_data/train/annot
your_data/train/annot/annot_img1.png
your_data/train/annot/annot_img2.png
...
your_data/train/imgs
your_data/train/imgs/img1.jpg
your_data/train/imgs/img2.jpg
...

your_data/test
your_data/test/test.txt
your_data/test/annot
your_data/test/annot/annot_img100.png
your_data/test/annot/annot_img200.png
...
your_data/test/imgs
your_data/test/imgs/img100.jpg
your_data/test/imgs/img200.jpg
...
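With a layout like the one above, the list files can be generated rather than written by hand. This is a minimal sketch: `write_list_file` and `annot_prefix` are hypothetical helpers, and the naming convention (img1.jpg paired with annot_img1.png) is only an assumption read off the suggested tree, so adapt `mask_name` to your own data. Absolute paths are written, as in the Cityscapes example list.

```python
from pathlib import Path


def write_list_file(split_dir, list_name, mask_name):
    """Write <split_dir>/<list_name>, pairing each imgs/<file> with annot/<mask_name(file)>.

    Returns the number of image/mask pairs written.
    """
    split_dir = Path(split_dir)
    lines = []
    for image in sorted((split_dir / "imgs").iterdir()):
        if not image.is_file():
            continue  # skip sub-directories
        mask = split_dir / "annot" / mask_name(image.name)
        if not mask.exists():
            raise FileNotFoundError(f"no mask for {image}: expected {mask}")
        # Absolute paths, image first, a single space, then the mask.
        lines.append(f"{image.resolve()} {mask.resolve()}")
    (split_dir / list_name).write_text("\n".join(lines) + "\n")
    return len(lines)


def annot_prefix(name):
    """Hypothetical naming convention: img1.jpg -> annot_img1.png."""
    return "annot_" + Path(name).stem + ".png"


# write_list_file("your_data/train", "train.txt", annot_prefix)
# write_list_file("your_data/test", "test.txt", annot_prefix)
```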

The DD platform has the following requirements for training an image segmentation model:

  • All data must be in image format; most encodings are supported (e.g. png, jpg, …)
  • For every image, there must be a mask describing the class of every pixel, itself stored as an 8-bit, single-channel image. In other words, the mask must be a grayscale image with values between 0 and 255 at most, where each value represents a class. For a two-class model (background and one object type), pixels can only be 0 or 1; for a three-class problem, pixels can only be 0, 1 or 2, and so on. See the example above.
  • A main text file lists every image path followed by the path of its mask, separated by a space. See the format example above.

  • You need to prepare both a train.txt and a test.txt file, for training and testing respectively.
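The mask requirements above can be checked programmatically before training. This is a minimal sketch: `check_mask` is a hypothetical helper that takes an image mode string and the mask's pixel values from whatever image library you use. The `allowed_extra` parameter is an assumption for pipelines that reserve a special value (such as an ignore label); the requirements above do not mention one.

```python
def check_mask(mode, pixels, nclasses, allowed_extra=()):
    """Return a list of problems found in a mask; an empty list means it looks valid.

    mode: image mode string (Pillow reports "L" for 8-bit single-channel images).
    pixels: iterable of integer pixel values.
    allowed_extra: values tolerated besides class ids (assumption, see lead-in).
    """
    if mode != "L":
        return [f"mask mode is {mode!r}, expected 8-bit single channel ('L')"]
    # Valid class ids run from 0 to nclasses - 1 (0 and 1 for two classes).
    bad = sorted({v for v in pixels if not 0 <= v < nclasses and v not in allowed_extra})
    if bad:
        return [f"pixel values {bad} outside 0..{nclasses - 1}"]
    return []
```

With Pillow, for instance, `im = Image.open(mask_path)` exposes `im.mode`, and for an "L" image `im.tobytes()` yields one integer per pixel, so `check_mask(im.mode, im.tobytes(), nclasses)` checks a whole mask.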

The DD platform comes with a custom Jupyter UI that lets you check your image segmentation dataset before training:

Image segmentation data check in DD platform Jupyter UI
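Outside the UI, a quick check that every path in a list file points to an existing file can catch typos before a long training run. This is a minimal sketch with a hypothetical helper name; it does not reproduce what the Jupyter UI checks.

```python
from pathlib import Path


def missing_files(list_path):
    """Return (line number, path) for every listed image or mask that does not exist."""
    missing = []
    with open(list_path) as f:
        for lineno, line in enumerate(f, start=1):
            for path in line.split():
                if not Path(path).is_file():
                    missing.append((lineno, path))
    return missing


# for lineno, path in missing_files("/opt/platform/examples/cityscapes/train.txt"):
#     print(f"line {lineno}: missing {path}")
```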

Training an image segmenter

Using the DD platform, start from the code below in a JupyterLab notebook.

Image segmentation notebook snippet:

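# Assumed import path: the Segmentation widget ships with the DD platform
# notebooks (dd_widgets package); adjust if your install differs.
from dd_widgets import Segmentation

city_seg = Segmentation(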
  'city_psp',
  training_repo='/opt/platform/examples/cityscapes/train.txt',
  testing_repo='/opt/platform/examples/cityscapes/test_shuf50.txt',
  host='deepdetect_training',
  port=8080,
  img_height=480,
  img_width=480,
  model_repo='/opt/platform/models/training/beniz/cityscapes/',
  nclasses=8,
  template='pspnet_vgg16',
  iterations=75000,
  test_interval=1000,
  snapshot_interval=1000,
  batch_size=6,
  test_batch_size=1,
  noise_prob=0.001,
  distort_prob=0.001,
  gpuid=0,
  base_lr=0.001,
  weights='/opt/platform/models/pretrained/pspnet_vgg16/vgg16_init_deeplab.caffemodel',
  solver_type="AMSGRAD",
  finetune=True,
  rotate=False,
  mirror=True,
  resume=False,
  loss='dice_weighted'
  )  # chain .run() here to launch the training job
city_seg  # display the job in the notebook

This prepares the training of an image segmenter with the following parameters:

  • city_psp is the example job name
  • training_repo specifies the location of the data
  • template specifies a PSPNet architecture with a VGG-16 backbone, which delivers state-of-the-art performance.

  • loss specifies the loss used to train the model. For two-class problems, dice works best. For multi-class problems, weighted softmax or dice_weighted are recommended. When using the softmax loss, see how to set class_weights in the main API documentation at https://deepdetect.com/api/#launch-a-training-job, as class weighting improves training on segmentation tasks, which generally have very unbalanced classes.

  • img_width and img_height specify the input image size; see the recommended models table below to adapt them to the other available architectures.

  • noise_prob and distort_prob control the random occurrence of dozens of data augmentation schemes. 0.001 is usually a good value for both when data is scarce.

  • mirror activates mirroring of inputs as data augmentation for both the input image and the mask

  • rotate activates rotation of inputs as data augmentation for both the input image and the mask (e.g. useful for satellite images, …)

  • finetune automatically prepares the network architecture for finetuning

  • weights specifies the pre-trained model weights to start training from

  • solver_type specifies the optimizer; see solver_type in https://deepdetect.com/api/#launch-a-training-job for the many available options

  • base_lr specifies the learning rate. For finetuning segmentation models, 1e-3 works well.

  • gpuid specifies which GPU to use, starting with number 0
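The loss bullet above points to class_weights because segmentation classes are usually very unbalanced. One common heuristic, not necessarily the one DD recommends, weights each class by the inverse of its pixel frequency. This minimal sketch (hypothetical helper `inverse_frequency_weights`) computes such weights from pixel counts; how the resulting list is passed to the training job is described in the API documentation and is not shown here.

```python
from collections import Counter


def inverse_frequency_weights(pixel_counts, nclasses):
    """Weight class c by total / (nclasses * count_c).

    Every class present then contributes the same total weight
    (total / nclasses) to the loss. Absent classes get weight 0.
    """
    total = sum(pixel_counts.get(c, 0) for c in range(nclasses))
    weights = []
    for c in range(nclasses):
        count = pixel_counts.get(c, 0)
        weights.append(total / (nclasses * count) if count else 0.0)
    return weights


# Accumulate pixel counts over the training masks (toy 4-pixel masks here).
counts = Counter()
for mask_pixels in (bytes([0, 0, 0, 1]), bytes([0, 0, 1, 1])):
    counts.update(mask_pixels)
print(inverse_frequency_weights(counts, nclasses=2))
```

The rarer class receives the larger weight, which counteracts the dominance of background pixels in the loss.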

The platform has many neural network architectures and pre-trained models built in for image segmentation. These range from state-of-the-art architectures like PSPNet and Deeplab, to U-Net for flexibility and SE-Net for low-memory and embedded tasks.

Below is a list of recommended models for image segmentation, from which to choose the best fit for your task.

Model   | Image size | Template      | Pre-trained (/opt/platform/models/pretrained) | Recommendation
--------|------------|---------------|-----------------------------------------------|---------------------------------------------------
PSPNet  | 480x480    | pspnet_vgg16  | pspnet_vgg16/vgg16_init_deeplab.caffemodel    | Fast / Excellent accuracy / desktops
Deeplab | 480x480    | deeplab_vgg16 | deeplab_vgg16/vgg16_init_deeplab.caffemodel   | Fast / Excellent accuracy / desktops
U-Net   | 480x480    | unet          | none, from scratch                            | Flexible / Very good accuracy / embedded & desktops
E-Net   | 480x480    | enet          | none, from scratch                            | Fast / Average accuracy / embedded
SE-Net  | 224x224    | se_net        | none, from scratch                            | Extremely fast / Average accuracy / embedded
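To keep the template, the input size and the pre-trained weights consistent when switching architectures, the table can be transcribed into a small lookup. This is a minimal sketch: `RECOMMENDED` and `template_kwargs` are hypothetical names, and setting `finetune=True` whenever pre-trained weights exist simply mirrors the PSPNet example above.

```python
# Values transcribed from the recommended models table.
PRETRAINED_ROOT = "/opt/platform/models/pretrained"

RECOMMENDED = {
    # template: (square input size, weights relative to PRETRAINED_ROOT, or None)
    "pspnet_vgg16": (480, "pspnet_vgg16/vgg16_init_deeplab.caffemodel"),
    "deeplab_vgg16": (480, "deeplab_vgg16/vgg16_init_deeplab.caffemodel"),
    "unet": (480, None),
    "enet": (480, None),
    "se_net": (224, None),
}


def template_kwargs(template):
    """Return keyword arguments for a Segmentation(...) call matching a template."""
    size, weights = RECOMMENDED[template]
    kwargs = {"template": template, "img_height": size, "img_width": size}
    if weights is not None:
        # Pre-trained weights available: finetune from them, as in the PSPNet example.
        kwargs["weights"] = f"{PRETRAINED_ROOT}/{weights}"
        kwargs["finetune"] = True
    return kwargs
```

For example, `Segmentation('my_job', ..., **template_kwargs('unet'))` would train a U-Net from scratch at 480x480, with the other parameters set as in the notebook snippet.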
