PyTorch implementation of siamese and triplet networks for learning embeddings.
Siamese and triplet networks are useful for learning mappings from images to a compact Euclidean space where distances correspond to a measure of similarity [2]. Embeddings trained in this way can be used as feature vectors for classification or few-shot learning tasks.
Requires PyTorch 0.4 with torchvision 0.2.1.
For PyTorch 0.3 compatibility, check out the tag torch-0.3.1.
We'll train embeddings on the MNIST dataset. Experiments were run in a Jupyter notebook.
We'll go through learning supervised feature embeddings using different loss functions on the MNIST dataset. This is just for visualization purposes, so we'll use 2-dimensional embeddings, which isn't the best choice in practice.
For every experiment the same embedding network is used (32 conv 5x5 -> PReLU -> MaxPool 2x2 -> 64 conv 5x5 -> PReLU -> MaxPool 2x2 -> Dense 256 -> PReLU -> Dense 256 -> PReLU -> Dense 2) and we don't perform any hyperparameter search.
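The architecture listed above could be sketched in PyTorch roughly as follows. The class name `EmbeddingNet`, the lack of padding, the convolution stride of 1 and the pooling stride of 2 are assumptions made for this illustration; the text only gives the layer sequence and sizes.

```python
import torch
import torch.nn as nn


class EmbeddingNet(nn.Module):
    """Sketch of the embedding network for 1x28x28 MNIST images (assumed input size)."""

    def __init__(self):
        super().__init__()
        self.convnet = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=5), nn.PReLU(),   # 28x28 -> 24x24 (no padding assumed)
            nn.MaxPool2d(2, stride=2),                     # 24x24 -> 12x12
            nn.Conv2d(32, 64, kernel_size=5), nn.PReLU(),  # 12x12 -> 8x8
            nn.MaxPool2d(2, stride=2),                     # 8x8 -> 4x4
        )
        self.fc = nn.Sequential(
            nn.Linear(64 * 4 * 4, 256), nn.PReLU(),
            nn.Linear(256, 256), nn.PReLU(),
            nn.Linear(256, 2),                             # 2-dimensional embedding
        )

    def forward(self, x):
        features = self.convnet(x)
        features = features.view(features.size(0), -1)     # flatten to (batch, 1024)
        return self.fc(features)


# Push a random batch of 8 fake images through the network.
embeddings = EmbeddingNet()(torch.randn(8, 1, 28, 28))
print(embeddings.shape)  # torch.Size([8, 2])
```

With these assumptions the flattened feature map has 64 * 4 * 4 = 1024 values; a different padding or stride choice would change the input size of the first dense layer.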
We add a fully-connected layer with one output per class and train the network for classification with softmax and cross-entropy. The network trains to ~99% accuracy. We extract 2-dimensional embeddings from the penultimate layer:
Train set:
Test set:
While the embeddings look separable (which is what we trained them for), they don't have good metric properties. They might not be the best choice as a descriptor for new classes.
Now we'll train a siamese network that takes a pair of images and trains the embeddings so that the distance between them is minimized if they're from the same class and is greater than some margin value if they represent different classes. We'll minimize a contrastive loss function [1]:
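One common way to write a contrastive loss in the spirit of [1] is sketched below. The function name `contrastive_loss`, the convention that `target` is 1 for same-class pairs and 0 otherwise, the default margin, and the small `eps` inside the square root are assumptions for this example, not necessarily the exact form used in the experiments.

```python
import torch
import torch.nn.functional as F


def contrastive_loss(emb1, emb2, target, margin=1.0, eps=1e-9):
    """Contrastive loss sketch; target is 1 for same-class pairs, 0 otherwise (assumed convention)."""
    target = target.float()
    sq_dist = (emb1 - emb2).pow(2).sum(dim=1)            # squared Euclidean distance per pair
    dist = (sq_dist + eps).sqrt()                        # eps keeps the gradient finite at 0
    positive_term = target * sq_dist                     # pull same-class pairs together
    negative_term = (1 - target) * F.relu(margin - dist).pow(2)  # push others beyond the margin
    return 0.5 * (positive_term + negative_term).mean()


emb1 = torch.tensor([[0.0, 0.0], [0.0, 0.0]])
emb2 = torch.tensor([[3.0, 4.0], [0.6, 0.8]])
target = torch.tensor([1, 0])                            # first pair same class, second pair different
loss = contrastive_loss(emb1, emb2, target, margin=2.0)
print(loss.item())  # approximately 6.5
```

In the example, the same-class pair is 5 apart and contributes 0.5 * 25, while the different-class pair is 1 apart with a margin of 2 and contributes 0.5 * 1; the mean of the two is about 6.5.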
The SiameseMNIST class samples random positive and negative pairs that are then fed to the siamese network.
After 20 epochs of training, here are the embeddings we get for the training set:
Test set:
The learned embeddings are clustered much better within class.
We'll train a triplet network that takes an anchor, a positive example (of the same class as the anchor) and a negative example (of a different class than the anchor). The objective is to learn embeddings such that the anchor is closer to the positive example than it is to the negative example by some margin value.
Source: Schroff, Florian, Dmitry Kalenichenko, and James Philbin. Facenet: A unified embedding for face recognition and clustering. CVPR 2015.
Triplet loss:
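A minimal sketch of a triplet loss matching the objective described above, using squared Euclidean distances as in the FaceNet formulation [2]. The function name `triplet_loss` and the default margin are assumptions for this example.

```python
import torch
import torch.nn.functional as F


def triplet_loss(anchor, positive, negative, margin=1.0):
    """Triplet loss sketch: mean of max(0, d(a, p) - d(a, n) + margin) over the batch."""
    d_pos = (anchor - positive).pow(2).sum(dim=1)  # squared distance anchor -> positive
    d_neg = (anchor - negative).pow(2).sum(dim=1)  # squared distance anchor -> negative
    return F.relu(d_pos - d_neg + margin).mean()


anchor = torch.tensor([[0.0, 0.0], [0.0, 0.0]])
positive = torch.tensor([[1.0, 0.0], [1.0, 0.0]])
negative = torch.tensor([[3.0, 0.0], [0.0, 1.0]])
loss = triplet_loss(anchor, positive, negative, margin=1.0)
print(loss.item())  # 0.5
```

The first triplet already satisfies the margin (1 - 9 + 1 < 0) and contributes nothing; the second has the negative as close as the positive (1 - 1 + 1 = 1), so the mean is 0.5.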
The TripletMNIST class samples a positive and a negative example for every possible anchor.
After 20 epochs of training, here are the embeddings we get for the training set:
Test set:
The learned embeddings are not as close to each other within a class as with the siamese network, but that's not what we optimized them for. We wanted the embeddings to be closer to other embeddings from the same class than to embeddings from other classes, and we can see that this is where the training is heading.
There are a couple of problems with siamese and triplet networks:
To deal with these issues efficiently, we'll feed the network with standard mini-batches, as we did for classification. The loss function will be responsible for selecting hard pairs and triplets within a mini-batch. If we feed the network with 16 images for each of 10 classes (160 images in total), we can process up to 159*160/2 = 12720 pairs and 10*16*15/2*(9*16) = 172800 triplets, compared to 80 pairs and 53 triplets in the previous implementation.
Usually it's not the best idea to process all possible pairs or triplets within a mini-batch. Some strategies for selecting triplets are described in [2] and [3].
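The counts quoted above can be checked with a few lines of plain Python. The variable names are chosen for this example; the offline figures assume each image is used in exactly one pair or one triplet.

```python
from math import comb

n_classes, n_samples = 10, 16
batch_size = n_classes * n_samples                   # 160 images per mini-batch

all_pairs = comb(batch_size, 2)                      # every unordered pair in the batch
positive_pairs = n_classes * comb(n_samples, 2)      # anchor-positive pairs within each class
negatives_per_pair = (n_classes - 1) * n_samples     # images from the 9 other classes
all_triplets = positive_pairs * negatives_per_pair

offline_pairs = batch_size // 2                      # each image used in one pair
offline_triplets = batch_size // 3                   # each image used in one triplet

print(all_pairs)         # 12720
print(all_triplets)      # 172800
print(offline_pairs)     # 80
print(offline_triplets)  # 53
```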
We'll feed the network with mini-batches, as we did for the classification network. This time we'll use a special BatchSampler that samples n_classes classes and n_samples examples within each class, resulting in mini-batches of size n_classes*n_samples.
For each mini-batch, positive and negative pairs will be selected using the provided labels.
MNIST is a rather easy dataset and the embeddings from the randomly selected pairs were already quite good, so we don't see much improvement here.
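The sampling scheme can be illustrated with a small standard-library generator. The name `balanced_batches` and its reshuffling behaviour are hypothetical and written only for this sketch; it assumes every class has at least n_samples examples.

```python
import random
from collections import Counter, defaultdict


def balanced_batches(labels, n_classes, n_samples, seed=0):
    """Yield lists of dataset indices holding n_samples items from each of n_classes classes."""
    rng = random.Random(seed)
    indices_by_class = defaultdict(list)
    for index, label in enumerate(labels):
        indices_by_class[label].append(index)
    for indices in indices_by_class.values():
        rng.shuffle(indices)

    used = defaultdict(int)  # how many indices of each class have been served so far
    n_batches = len(labels) // (n_classes * n_samples)
    for _ in range(n_batches):
        classes = rng.sample(sorted(indices_by_class), n_classes)
        batch = []
        for c in classes:
            # Reshuffle and start over once a class runs out of unseen indices.
            if used[c] + n_samples > len(indices_by_class[c]):
                rng.shuffle(indices_by_class[c])
                used[c] = 0
            batch.extend(indices_by_class[c][used[c]:used[c] + n_samples])
            used[c] += n_samples
        yield batch


labels = [i % 10 for i in range(1000)]               # toy labels: 10 classes, 100 items each
batches = list(balanced_batches(labels, n_classes=5, n_samples=4))
per_class = Counter(labels[i] for i in batches[0])
print(len(batches), len(batches[0]))                 # 50 20
print(sorted(per_class.values()))                    # [4, 4, 4, 4, 4]
```

A sampler class wrapping the same logic could be handed to `torch.utils.data.DataLoader` through its `batch_sampler` argument, which expects an iterable that yields lists of indices.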
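As an illustration of selecting pairs from labels alone, the sketch below splits every index pair in a mini-batch into positive and negative pairs. The helper name `split_pairs` is hypothetical, and a real selector would typically keep only a subset of the negatives (for example the hardest ones) rather than all of them.

```python
from itertools import combinations


def split_pairs(labels):
    """Split all index pairs of a mini-batch into same-class and different-class pairs."""
    positive, negative = [], []
    for i, j in combinations(range(len(labels)), 2):
        if labels[i] == labels[j]:
            positive.append((i, j))
        else:
            negative.append((i, j))
    return positive, negative


positive_pairs, negative_pairs = split_pairs([0, 0, 1, 1, 2])
print(positive_pairs)       # [(0, 1), (2, 3)]
print(len(negative_pairs))  # 8
```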
Train embeddings:
Test embeddings:
We'll feed the network with mini-batches, just like with online pair selection. There are a couple of strategies we can use for triplet selection given labels and predicted embeddings:
The strategy for triplet selection must be chosen carefully. A bad strategy might lead to inefficient training or, even worse, to model collapse (all embeddings ending up with the same values).
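One possible selection strategy, picking a random hard negative for each anchor-positive pair, might look like the sketch below. Here a negative counts as "hard" when it gives a positive triplet loss, distances are squared Euclidean, and the function name `random_hard_negative_triplets` is hypothetical; all three are assumptions of this example rather than a description of the repository's code.

```python
from itertools import combinations

import torch


def random_hard_negative_triplets(embeddings, labels, margin=1.0, seed=0):
    """Return an (n_triplets, 3) LongTensor of (anchor, positive, negative) indices."""
    generator = torch.Generator().manual_seed(seed)
    diffs = embeddings.unsqueeze(1) - embeddings.unsqueeze(0)
    sq_dist = diffs.pow(2).sum(dim=-1)                  # pairwise squared distances
    triplets = []
    for anchor, positive in combinations(range(len(labels)), 2):
        if labels[anchor] != labels[positive]:
            continue                                    # not an anchor-positive pair
        negative_idx = torch.nonzero(labels != labels[anchor]).flatten()
        # Triplet loss value for every candidate negative of this pair.
        violation = sq_dist[anchor, positive] - sq_dist[anchor, negative_idx] + margin
        hard = negative_idx[violation > 0]
        if len(hard) > 0:
            choice = torch.randint(len(hard), (1,), generator=generator).item()
            triplets.append((anchor, positive, hard[choice].item()))
    if not triplets:
        return torch.empty((0, 3), dtype=torch.long)
    return torch.tensor(triplets, dtype=torch.long)


embeddings = torch.tensor([[0.0, 0.0], [1.0, 0.0], [0.5, 0.0], [5.0, 5.0]])
labels = torch.tensor([0, 0, 1, 1])
triplets = random_hard_negative_triplets(embeddings, labels, margin=1.0)
print(triplets.shape)  # torch.Size([2, 3])
```

In the toy batch, the pair (0, 1) has only one hard negative (index 2, since index 3 is far away), while the pair (2, 3) can be matched with either index 0 or index 1. Pairs without any hard negative are skipped, so easy triplets never enter the loss.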
Here's what we got with random hard negatives for each positive pair.
Training set:
Test set:
Similar experiments were conducted on the FashionMNIST dataset, where the advantages of online negative mining are slightly more visible. The exact same network architecture with only 2-dimensional embeddings was used, which is probably not complex enough for learning good embeddings. More complex datasets with a higher number of classes should benefit even more from online mining.
Siamese network with randomly selected pairs
Online contrastive loss with negative mining
Triplet network with random triplets
Online triplet loss with negative mining
[1] Raia Hadsell, Sumit Chopra, Yann LeCun, Dimensionality reduction by learning an invariant mapping, CVPR 2006
[2] Florian Schroff, Dmitry Kalenichenko, James Philbin, FaceNet: A unified embedding for face recognition and clustering, CVPR 2015
[3] Alexander Hermans, Lucas Beyer, Bastian Leibe, In Defense of the Triplet Loss for Person Re-Identification, 2017
[4] Brandon Amos, Bartosz Ludwiczuk, Mahadev Satyanarayanan, OpenFace: A general-purpose face recognition library with mobile applications, 2016
[5] Yi Sun, Xiaogang Wang, Xiaoou Tang, Deep Learning Face Representation by Joint Identification-Verification, NIPS 2014