Implementing contrastive loss and triplet loss in TensorFlow
I started playing with TensorFlow two days ago and I'm wondering whether the triplet and contrastive losses are implemented.

I've been looking through the documentation, but I haven't found any example or description of these things.

Ropable answered 8/7, 2016 at 6:20 Comment(0)

Update (2018/03/19): I wrote a blog post detailing how to implement triplet loss in TensorFlow.


You need to implement the contrastive loss or the triplet loss yourself, but once you know the pairs or triplets this is quite easy.


Contrastive Loss

Suppose you have pairs of data and their label as input (positive or negative, i.e. same class or different class). For instance, you have images of size 28x28x1 as input:

left = tf.placeholder(tf.float32, [None, 28, 28, 1])
right = tf.placeholder(tf.float32, [None, 28, 28, 1])
label = tf.placeholder(tf.int32, [None, 1])  # 0 if same, 1 if different
margin = 0.2

left_output = model(left)  # shape [None, 128]
right_output = model(right)  # shape [None, 128]

# cast the label to float and flatten it to shape [None] so it broadcasts with d
label = tf.to_float(tf.reshape(label, [-1]))

d = tf.reduce_sum(tf.square(left_output - right_output), 1)
d_sqrt = tf.sqrt(d)

loss = label * tf.square(tf.maximum(0., margin - d_sqrt)) + (1 - label) * d

loss = 0.5 * tf.reduce_mean(loss)

Triplet Loss

Same as with contrastive loss, but with triplets (anchor, positive, negative). You don't need labels here.

anchor_output = ...  # shape [None, 128]
positive_output = ...  # shape [None, 128]
negative_output = ...  # shape [None, 128]

d_pos = tf.reduce_sum(tf.square(anchor_output - positive_output), 1)
d_neg = tf.reduce_sum(tf.square(anchor_output - negative_output), 1)

loss = tf.maximum(0., margin + d_pos - d_neg)
loss = tf.reduce_mean(loss)

The real trouble when implementing triplet loss or contrastive loss in TensorFlow is how to sample the triplets or pairs. I will focus on generating triplets because it is harder than generating pairs.

The easiest way is to generate them outside of the TensorFlow graph, i.e. in Python, and feed them to the network through the placeholders. Basically you select images 3 at a time, with the first two from the same class and the third from another class. You then perform a feedforward on these triplets and compute the triplet loss.
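As a rough sketch of this offline approach (the helper name sample_triplets and the assumption that images and labels are NumPy arrays are mine, not a reference implementation):

import numpy as np

def sample_triplets(images, labels, batch_size):
    """Naively sample batch_size (anchor, positive, negative) triplets."""
    anchors, positives, negatives = [], [], []
    classes = np.unique(labels)
    for _ in range(batch_size):
        # one class for the anchor/positive pair, a different one for the negative
        pos_class, neg_class = np.random.choice(classes, 2, replace=False)
        a, p = np.random.choice(np.where(labels == pos_class)[0], 2, replace=False)
        n = np.random.choice(np.where(labels == neg_class)[0])
        anchors.append(a)
        positives.append(p)
        negatives.append(n)
    return images[anchors], images[positives], images[negatives]

The three batches would then be fed through three placeholders (anchor, positive, negative), mirroring the left/right placeholders above.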

The issue here is that generating triplets is complicated. We want them to be useful triplets, i.e. triplets with a positive loss (otherwise the loss is 0 and the network doesn't learn).
To know whether a triplet is good or not you need to compute its loss, so you already need one feedforward through the network...

Clearly, implementing triplet loss in TensorFlow is hard, and there are ways to make it more efficient than sampling in Python, but explaining them would require a whole blog post! One such trick is sketched below.
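This trick, discussed in the comments below (the "batch all" idea from OpenFace), is to form every possible triplet inside a batch directly in the graph, mask out the invalid ones, and average only the losses that are still positive. A rough sketch under assumed shapes; this is my paraphrase, not the exact OpenFace code:

# Assumed inputs: embeddings of shape [batch_size, 128] from one feedforward,
# labels of shape [batch_size] (tf.int32), and a scalar margin.
dot = tf.matmul(embeddings, embeddings, transpose_b=True)
sq_norms = tf.diag_part(dot)
# dists[i, j] = squared distance between embeddings i and j
dists = tf.expand_dims(sq_norms, 1) - 2.0 * dot + tf.expand_dims(sq_norms, 0)
dists = tf.maximum(dists, 0.0)  # guard against tiny negative values

# triplet_loss[a, p, n] = d(a, p) - d(a, n) + margin, for every combination
triplet_loss = tf.expand_dims(dists, 2) - tf.expand_dims(dists, 1) + margin

# keep only valid triplets: a != p, label(a) == label(p), label(a) != label(n)
same_label = tf.equal(tf.expand_dims(labels, 0), tf.expand_dims(labels, 1))
distinct = tf.logical_not(tf.cast(tf.eye(tf.shape(labels)[0]), tf.bool))
valid = tf.logical_and(
    tf.expand_dims(tf.logical_and(same_label, distinct), 2),   # constraints on (a, p)
    tf.expand_dims(tf.logical_not(same_label), 1))             # constraint on (a, n)
triplet_loss = tf.maximum(0.0, triplet_loss) * tf.to_float(valid)

# average only over triplets that are still active, as discussed in the comments
num_active = tf.reduce_sum(tf.to_float(tf.greater(triplet_loss, 1e-16)))
loss = tf.reduce_sum(triplet_loss) / (num_active + 1e-16)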

Such answered 8/7, 2016 at 15:23 Comment(13)
Hi @Olivier, I am very interested in the sampling part. Would you or have you posted a blog about it? I am doing just what you said: feed forward once, compute the losses for all possible triplets, filter out invalid ones, and sample a batch to do another forward+backward...Windstorm
Didn't write any blog post. One key insight is to compute all the possible triplets as explained in OpenFace; my answer above contains the old solution. To remove the middle sess.run() call, you can add a tf.py_func operation inside the graph to filter out the bad triplets.Such
@weitang114: Another way for the 2nd part is to just compute the loss for all the triplets, removing only the invalid triplets (i.e. (+, +, +)), which can be computed in advance. This converges well, surprisingly.Such
Thank you for this advice. I didn't get the idea at the time, but found it very useful recently. This process implemented in tf helped me reduce training time from 5 days to 1 day. :)Windstorm
@weitang114: Yeah it's very convenient. Did you implement it without tf.py_func (the second idea I gave)?Such
No. I implemented it almost the same as described in the OpenFace article. I used tf.nn.relu() to filter out useless losses and counted how many losses are left, say C; then the mean loss is sum(losses)/C.Windstorm
@Windstorm how did you manage to select only the valid triplets for training? When I verify if the loss is > 0 for a set of triplets (anchor image, positive image, negative image), I have to feed the triplets again to the model to calculate the gradients. And because I use dropout, the same triplets might give loss 0 at the next feed. I'm stuck.Corr
@OlivierMoindrot can you please provide an example of filtering bad triplets using py_func?Corr
Why is d used instead of d_sqrt at the end of the first assignment to loss in the contrastive loss?Quiteria
This is the formula for contrastive lossSuch
@HelloLili: I finally wrote that blog post. Here it is: omoindrot.github.io/triplet-lossSuch
@weitang114: I finally wrote that blog post. Here it is: omoindrot.github.io/triplet-lossSuch
The aforementioned Contrastive Loss code should be modified a little to avoid a NaN error, i.e. d_sqrt = tf.sqrt(d + 1e-7). We used the code and found the bug.Stonedead

Triplet loss with semihard negative mining is now implemented in tf.contrib, as follows:

triplet_semihard_loss(
    labels,
    embeddings,
    margin=1.0
)

where:

Args:

  • labels: 1-D tf.int32 Tensor with shape [batch_size] of multiclass integer labels.

  • embeddings: 2-D float Tensor of embedding vectors. Embeddings should be L2-normalized.

  • margin: Float, margin term in the loss definition.

Returns:

  • triplet_loss: tf.float32 scalar.

For further information, check the link below:

https://www.tensorflow.org/versions/master/api_docs/python/tf/contrib/losses/metric_learning/triplet_semihard_loss
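For example, here is a minimal usage sketch (model, images and labels stand in for whatever network and inputs you already have):

import tensorflow as tf

embeddings = model(images)                      # e.g. shape [batch_size, 128]
embeddings = tf.nn.l2_normalize(embeddings, 1)  # the op expects L2-normalized embeddings

loss = tf.contrib.losses.metric_learning.triplet_semihard_loss(
    labels=labels,           # [batch_size] tf.int32 class ids
    embeddings=embeddings,
    margin=1.0)

train_op = tf.train.AdamOptimizer(1e-4).minimize(loss)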

Platinize answered 10/1, 2018 at 1:49 Comment(2)
Link only answers? Include some relevant portions from the link here.Whereat
While this link might provide some limited, immediate help, an answer should include sufficient context around the link so your fellow users will have some idea what it is and why it’s there. Always quote the most relevant part of an important link, to make it more useful to future readers with other, similar questions. In addition, other users tend to respond negatively to answers which are barely more than a link to an external site, and they might be deleted.Costanza

Tiago, I don't think you are using the same formula Olivier gave. Here is the right code (not sure it will work though; I'm just fixing the formula):

import tensorflow as tf


def compute_euclidean_distance(x, y):
    """
    Computes the squared euclidean distance between two tensorflow variables
    """

    d = tf.reduce_sum(tf.square(tf.subtract(x, y)), 1)
    return d


def compute_contrastive_loss(left_feature, right_feature, label, margin):

    """
    Compute the contrastive loss as in


    L = 0.5 * (1-Y) * D^2 + 0.5 * Y * {max(0, margin - D)}^2

    **Parameters**
     left_feature: First element of the pair
     right_feature: Second element of the pair
     label: Label of the pair (0 if same class, 1 if different)
     margin: Contrastive margin

    **Returns**
     Return the loss operation

    """

    label = tf.to_float(label)
    one = tf.constant(1.0)

    d = compute_euclidean_distance(left_feature, right_feature)
    d_sqrt = tf.sqrt(compute_euclidean_distance(left_feature, right_feature))
    first_part = tf.multiply(one - label, d)  # (1-Y) * D^2

    max_part = tf.square(tf.maximum(margin - d_sqrt, 0.0))
    second_part = tf.multiply(label, max_part)  # Y * {max(0, margin - D)}^2

    loss = 0.5 * tf.reduce_mean(first_part + second_part)

    return loss
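
For completeness, a hypothetical usage sketch, reusing the placeholders from Olivier's answer above:

left_output = model(left)    # shape [None, 128]
right_output = model(right)  # shape [None, 128]

loss = compute_contrastive_loss(left_output, right_output, label, margin=0.2)
train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)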
Gleeman answered 16/7, 2016 at 23:51 Comment(8)
Hi Wassim, thanks for the fix, just a patch to your code: d_sqrt = tf.sqrt(compute_euclidean_distance(left_feature, right_feature)). But even with this fix, I get very low accuracy (though the loss decreases as expected).Ropable
@TiagoFreitasPereira I am having the same problem with my triplet loss implementation. I will notify you if I find a solution...Gleeman
Hey @Wassim, thanks. If it is easier, you can try to bootstrap my project (github.com/tiagofrepereira2012/examples.tensorflow).Ropable
@TiagoFreitasPereira, it seems like it has to do with the way we implement the accuracy computation. When using triplet loss or contrastive loss you can't compute accuracy by label verification (the network wasn't trained to differentiate the 10 classes); you have to compute accuracy by evaluating whether the network guessed that two elements are from the same class or not.Gleeman
See section 4 and 5.6 of this paper arxiv.org/pdf/1503.03832v3.pdfGleeman
Hi @Wassim, yes, I understand that, but my goal here is to train the siamese net (or the triplet) and use one of the fully connected layers (fc1 or fc2 in my code) as features. In our example, since the network is good at differentiating digits, the trained features must be good.Ropable
The features will be good but you need to add a final softmax (and retrain it) on themSuch
The trained features are indeed good, but after you apply a softmax you have to beware of the indices: since we don't feed labels, an activation of the first neuron on the softmax layer doesn't necessarily signify that it detected a 0; it could be any of the other digits. Usually, if you train using contrastive/triplet loss you're aiming to use the network for comparison rather than classification.Gleeman
