Author: Samrat Saha | Compiled by: VK | Source: Towards Data Science

Supervised Contrastive Learning: this paper compares supervised learning with a cross-entropy loss against the supervised contrastive loss, with the goal of learning better image representations and improving classification. Let's take a closer look at its content.

The paper reports an improvement of about 1% on the ImageNet dataset.

Architecturally, it is a very simple network: a ResNet-50 encoder with a 128-dimensional projection head. You can add more layers if you want.

Code

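```python
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet50

class SupConNet(nn.Module):  # hypothetical wrapper class name
    def __init__(self):
        super().__init__()
        self.encoder = resnet50()
        # torchvision's resnet50 ends in a 1000-way fc layer; replace it with
        # Identity so the encoder outputs its 2048-dim pooled features
        self.encoder.fc = nn.Identity()
        self.head = nn.Linear(2048, 128)

    def forward(self, x):
        feat = self.encoder(x)
        # The 128-dim vector needs to be L2-normalized
        feat = F.normalize(self.head(feat), dim=1)
        return feat
```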

As shown in the figure, training is carried out in two stages:

  • Train the encoder on the training set with a contrastive loss (two variants: self-supervised and supervised).

  • Freeze the encoder's parameters, then learn a linear classifier on top with a softmax (cross-entropy) loss. (This follows the paper's procedure.)
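The second stage can be sketched roughly as follows. This is a minimal sketch under our own assumptions, not the paper's code: a tiny hypothetical encoder stands in for the trained ResNet-50, the data are random tensors, and we take a single SGD step with nn.CrossEntropyLoss (softmax plus negative log-likelihood) on a linear classifier while the encoder stays frozen.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for the stage-1 encoder (a ResNet-50 in the paper)
encoder = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 32), nn.ReLU())
classifier = nn.Linear(32, 10)  # the linear layer trained in stage 2

# Stage 2: freeze the encoder so that only the classifier learns
for p in encoder.parameters():
    p.requires_grad = False
encoder.eval()

optimizer = torch.optim.SGD(classifier.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()  # softmax + negative log-likelihood

images = torch.randn(16, 3, 8, 8)  # toy batch of 16 "images"
labels = torch.randint(0, 10, (16,))

encoder_before = [p.clone() for p in encoder.parameters()]
classifier_before = classifier.weight.clone()

with torch.no_grad():
    feats = encoder(images)  # no autograd graph through the frozen encoder
loss = criterion(classifier(feats), labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
```

Computing the features under torch.no_grad() is a design choice that saves memory; setting requires_grad = False alone would also keep the encoder fixed.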

So far, this is straightforward.

The main point of this paper is to understand the self-supervised contrastive loss and the supervised contrastive loss.

As can be seen from the SCL (supervised contrastive loss) diagram above, cats are contrasted with everything that is not a cat. All cat images share the same label and form positive pairs, while every non-cat image is a negative. This is very similar to how triplets and the triplet loss work.

Every cat image is also augmented, so even a single cat picture gives us several cat views to use as positives.
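A data pipeline for this usually wraps an augmentation so that each image comes out as two random views. The sketch below assumes a hypothetical TwoViewTransform wrapper and a toy augmentation (random horizontal flip plus small Gaussian noise) in place of the stronger augmentations used in practice.

```python
import torch

torch.manual_seed(0)

class TwoViewTransform:
    """Hypothetical helper: returns two independently augmented views of one image."""

    def __init__(self, transform):
        self.transform = transform

    def __call__(self, img):
        return [self.transform(img), self.transform(img)]

def random_augment(img):
    # Toy augmentation (assumption): random horizontal flip plus small Gaussian noise
    if torch.rand(1).item() < 0.5:
        img = torch.flip(img, dims=[-1])
    return img + 0.05 * torch.randn_like(img)

two_view = TwoViewTransform(random_augment)
cat = torch.rand(3, 32, 32)   # one "cat" image
view1, view2 = two_view(cat)  # two positives from a single picture
```

Because the augmentation is random, the two views differ, yet they share the cat's label and therefore count as a positive pair.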

The supervised contrastive loss function may look intimidating, but it is actually quite simple.

We’ll see some code later, but first a very simple explanation. Each z is a normalized 128-dimensional vector.

That is, ||z|| = 1.

To restate a fact from linear algebra: if two vectors u and v are normalized, then u · v = cos θ, where θ is the angle between u and v.

So if two normalized vectors are identical, their dot product equals 1.

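```python
# Try to understand the following code
import numpy as np

v = np.random.randn(128)   # random 128-dimensional vector
v = v / np.linalg.norm(v)  # normalize it to unit length
print(np.dot(v, v))        # ≈ 1.0: a unit vector dotted with itself
print(np.linalg.norm(v))   # ≈ 1.0
```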
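The cosine property also holds for two different vectors. In this small sketch (our own example, using 2-dimensional vectors for readability), two unit vectors 60 degrees apart have a dot product of cos 60° = 0.5, and rescaling a vector only affects the dot product until it is normalized again.

```python
import numpy as np

theta = np.deg2rad(60.0)
u = np.array([1.0, 0.0])                      # unit vector along the x-axis
v = np.array([np.cos(theta), np.sin(theta)])  # unit vector 60 degrees away from u
dot = np.dot(u, v)
print(dot, np.cos(theta))                     # both are 0.5 (up to rounding)

# Scaling a vector changes the raw dot product, but not after normalization
w = 3.0 * v
w_unit = w / np.linalg.norm(w)
print(np.dot(u, w), np.dot(u, w_unit))        # about 1.5 versus 0.5
```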

The loss function assumes that each image has one augmented version, so a batch of N images becomes a batch of size 2N.

For i ≠ j with y_i = y_j, the numerator exp(z_i · z_j / τ) covers all the other cats in the batch: the 128-dimensional vector z_i is dotted with every 128-dimensional vector z_j that shares its label.

The denominator dots the cat image z_i with every other image z_k in the batch, for all k ≠ i, that is, with every image except itself (cats and non-cats alike).

Finally, we take the log probability, sum it over all the other cat images in the batch (excluding the anchor itself), and divide by 2N_{y_i} − 1, the number of such positives, where N_{y_i} is the number of images in the original N-image batch with label y_i.

The total loss is the sum of these per-image losses over all 2N images.
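Written out as an equation, the loss described in these paragraphs reads as follows. This is our restatement in the same notation (z_i normalized, τ the temperature, N_{y_i} the number of images in the original N-image batch sharing anchor i's label), so check it against the paper before relying on it. The code later in this article averages over the 2N anchors instead of summing, which only rescales the loss by a constant.

$$
\mathcal{L}^{sup} = \sum_{i=1}^{2N} \mathcal{L}_i^{sup}, \qquad
\mathcal{L}_i^{sup} = \frac{-1}{2N_{y_i} - 1} \sum_{j=1}^{2N} \mathbb{1}_{i \neq j} \cdot \mathbb{1}_{y_i = y_j} \cdot \log \frac{\exp(z_i \cdot z_j / \tau)}{\sum_{k=1}^{2N} \mathbb{1}_{i \neq k} \exp(z_i \cdot z_k / \tau)}
$$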

Let's use some PyTorch code to understand this.

Assuming our batch size is 4, let's see how to calculate the loss of a single batch.

If the batch size is 4, the input to the network will be 8×3×224×224, where 224 is the image width and height.

The 8 = 4 × 2 comes from the fact that every image always comes with one augmented (contrast) view, so the data loader has to be written accordingly.

ResNet will output an 8×128 matrix, which you can split up to compute the batch loss.

```python
# batch size
bs = 4
```
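A sketch of the batch layout described above, assuming the first views of all images come first and their second views follow (so row i and row i + bs are the same image, which the mask code later relies on). Random tensors stand in for real images and for the network output; no actual model is run.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
bs = 4

# Hypothetical input: the first views of all 4 images, then their second views
view1 = torch.randn(bs, 3, 224, 224)
view2 = torch.randn(bs, 3, 224, 224)
x = torch.cat([view1, view2], dim=0)        # 8x3x224x224 network input

# Stand-in for the network's 8x128 normalized output on x
out = F.normalize(torch.randn(2 * bs, 128), dim=1)

f1, f2 = torch.split(out, [bs, bs], dim=0)  # two 4x128 halves, one per view
features = torch.stack([f1, f2], dim=1)     # 4x2x128: [image, view, dim]
# Flatten back to 8x128: all first views followed by all second views
contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
```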

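This part computes the numerator terms, z_i · z_j / τ, for every pair of images in the batch:

```python
import torch
import torch.nn.functional as F

bs = 4
temperature = 0.07

# Stand-in for the 8x128 (2*bs x 128) normalized network output
contrast_feature = F.normalize(torch.randn(2 * bs, 128), dim=1)
anchor_feature = contrast_feature

# (8x128) @ (128x8) -> 8x8 matrix of dot products divided by the temperature
anchor_dot_contrast = torch.div(
    torch.matmul(anchor_feature, contrast_feature.T),
    temperature)
```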

Our features have shape 8×128. To visualize, let's take a 3×128 matrix and its transpose. Here is the visualization.

anchor_feature is 3×128 and contrast_feature.T is 128×3, so the product is 3×3, as shown below.

Notice that every diagonal element is a point dotted with itself, which we actually don't want, so we'll remove them.
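Since the visualization itself is not reproduced here, the sketch below computes the 3×3 case directly, assuming random unit-length 128-dimensional features. The diagonal holds each vector dotted with itself, which after normalization is always 1/τ.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
temperature = 0.07
anchor_feature = F.normalize(torch.randn(3, 128), dim=1)  # 3x128, unit rows
contrast_feature = anchor_feature

sim = torch.matmul(anchor_feature, contrast_feature.T) / temperature  # 3x3
print(sim)
# Each diagonal entry is z_i . z_i / tau = 1 / 0.07, the largest value in its row
print(torch.diagonal(sim))
```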

Linear algebra tells us that for normalized vectors u and v, u · v is largest when u = v (it equals 1). So the maximum of each row is the anchor dotted with itself, and subtracting each row's maximum from that row sends every diagonal entry to 0. This shift does not change the softmax; it only keeps exp() from overflowing.
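Why the row maximum is subtracted can be checked numerically. In this sketch (random logits of our own choosing), exponentiating large values overflows float32, while the shifted version stays finite and gives exactly the same log-softmax.

```python
import torch

torch.manual_seed(0)
# Exponentiating large logits directly overflows float32
naive = torch.exp(torch.tensor([1000.0, 999.0]))
shifted = torch.exp(torch.tensor([1000.0, 999.0]) - 1000.0)  # subtract the max first
print(naive, shifted)  # naive is all inf; shifted is finite

# Subtracting the row maximum leaves the log-softmax unchanged
logits = torch.randn(3, 5) / 0.07
logits_max, _ = torch.max(logits, dim=1, keepdim=True)
stable = logits - logits_max  # every row now has maximum exactly 0
same = torch.allclose(torch.log_softmax(logits, dim=1),
                      torch.log_softmax(stable, dim=1), atol=1e-4)
print(same)
```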

Let's reduce the dimension from 128 to 2 so the numbers are easy to read.

```python
import torch

# bs = 1 and dim = 2, so features has shape (2*1) x 2
# (not normalized here, so entries are not bounded by 1/temperature)
features = torch.randn(2, 2)

temperature = 0.07
contrast_feature = features
anchor_feature = contrast_feature
anchor_dot_contrast = torch.div(
    torch.matmul(anchor_feature, contrast_feature.T),
    temperature)
print('anchor_dot_contrast=\n{}'.format(anchor_dot_contrast))

# Subtract each row's maximum from that row (in this run, the diagonal)
logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
print('logits_max = {}'.format(logits_max))
logits = anchor_dot_contrast - logits_max.detach()
print(' logits = {}'.format(logits))
```

See what happens on the diagonal:

```
anchor_dot_contrast=
tensor([[128.8697, -12.0467],
        [-12.0467,  50.5816]])
logits_max = tensor([[128.8697],
        [ 50.5816]])
 logits = tensor([[   0.0000, -140.9164],
        [ -62.6283,    0.0000]])
```

Now create some labels by hand and build the masks used in the contrast computation. This code is a bit involved, so check the output carefully.

```python
import torch

bs = 4
print('batch size', bs)
temperature = 0.07
labels = torch.randint(4, (1, 4))  # random labels in [0, 4), shape 1x4
print('labels', labels)
# (1x4) vs (4x1) broadcasts to 4x4: mask[i, j] = 1 if images i and j share a label
mask = torch.eq(labels, labels.T).float()

# Hardcode it to make it easier to understand
contrast_count = 2
anchor_count = contrast_count

# Tile the 4x4 mask to 8x8 so it covers both views of every image
mask = mask.repeat(anchor_count, contrast_count)

# mask self-contrast: ones everywhere except zeros on the diagonal
logits_mask = torch.scatter(
    torch.ones_like(mask),
    1,
    torch.arange(bs * anchor_count).view(-1, 1),
    0
)
print('logits_mask = \n{}'.format(logits_mask))
mask = mask * logits_mask
print('mask * logits_mask = \n{}'.format(mask))
```

Let's understand the output.

```
batch size 4
labels tensor([[3, 0, 2, 3]])

# The 4 images in this batch have labels 3, 0, 2, 3. Don't forget that each
# image also has exactly one augmented view, so the input batch additionally
# contains 3_c, 0_c, 2_c, 3_c as the contrast copies.

logits_mask = 
tensor([[0., 1., 1., 1., 1., 1., 1., 1.],
        [1., 0., 1., 1., 1., 1., 1., 1.],
        [1., 1., 0., 1., 1., 1., 1., 1.],
        [1., 1., 1., 0., 1., 1., 1., 1.],
        [1., 1., 1., 1., 0., 1., 1., 1.],
        [1., 1., 1., 1., 1., 0., 1., 1.],
        [1., 1., 1., 1., 1., 1., 0., 1.],
        [1., 1., 1., 1., 1., 1., 1., 0.]])

# This is very important: mask = mask * logits_mask tells us, for each anchor
# (row), which images it should be contrasted with as positives. Row 0, for
# example, is the 0th image.

# Our labels are tensor([[3, 0, 2, 3]]).
# Renamed to make the tensors easier to follow: ([[3_1, 0_1, 2_1, 3_2]])

mask * logits_mask = 
tensor([[0., 0., 0., 1., 1., 0., 0., 1.],
        [0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0.],
        [1., 0., 0., 0., 1., 0., 0., 1.],
        [1., 0., 0., 1., 0., 0., 0., 1.],
        [0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0.],
        [1., 0., 0., 1., 1., 0., 0., 0.]])
```
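As a deterministic check of the mask logic, the sketch below fixes the labels to [3, 0, 2, 3] and counts the positives in each row, which is the number of terms each anchor's log probability is averaged over. It uses 1 − torch.eye(...) as a simpler equivalent of the torch.scatter call, and verifies that equivalence.

```python
import torch

labels = torch.tensor([3, 0, 2, 3]).view(-1, 1)  # fixed labels from the example
bs, n_views = 4, 2
mask = torch.eq(labels, labels.T).float().repeat(n_views, n_views)

# 1 - eye gives the same "everything but the diagonal" mask as the scatter call
logits_mask = 1.0 - torch.eye(bs * n_views)
scatter_mask = torch.scatter(torch.ones(bs * n_views, bs * n_views), 1,
                             torch.arange(bs * n_views).view(-1, 1), 0)
print(torch.equal(logits_mask, scatter_mask))  # True

mask = mask * logits_mask
print(mask.sum(1))  # positives per anchor: tensor([3., 1., 1., 3., 3., 1., 1., 3.])
```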

Logits of each anchor against every contrast image, with the row maximum subtracted for numerical stability:

```python
logits = anchor_dot_contrast - logits_max.detach()
```

Loss function

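Math review

We already have the first part, the dot products divided by τ, as logits.

The second part of the equation, the log of the denominator, equals torch.log(exp_logits.sum(1, keepdim=True)). Because the row maximum was subtracted from logits, both parts are shifted by the same constant, which cancels in log_prob.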

```python
# Denominator terms: exp of every logit except the anchor with itself
exp_logits = torch.exp(logits) * logits_mask
log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

# Calculate the mean log-likelihood over each anchor's positives
mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)

# loss
loss = - mean_log_prob_pos

# Average over all anchor_count * bs anchors
loss = loss.view(anchor_count, bs).mean()
print('loss {}'.format(loss))
```
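Putting the pieces together, here is one possible self-contained version of the loss. It is a sketch assembled from the snippets above, not the official implementation: the function name supcon_loss and the [bs, n_views, dim] input layout are our own choices, anchors are averaged as in the code above, and with labels=None every image is treated as its own class.

```python
import torch
import torch.nn.functional as F


def supcon_loss(features, labels=None, temperature=0.07):
    """Supervised contrastive loss sketch.

    features: [bs, n_views, dim] tensor, L2-normalized along the last dim.
    labels:   [bs] integer labels, or None for the self-supervised case.
    """
    bs, n_views = features.shape[0], features.shape[1]
    if labels is None:
        # Each image is its own class: the only positives are its other views
        mask = torch.eye(bs, device=features.device)
    else:
        labels = labels.view(-1, 1)
        mask = torch.eq(labels, labels.T).float()  # 1 where labels match

    # Stack the views: rows 0..bs-1 are view 1, rows bs..2bs-1 are view 2, ...
    contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
    anchor_dot_contrast = torch.matmul(contrast_feature, contrast_feature.T) / temperature

    # Subtract the row maximum for numerical stability (cancels in log_prob)
    logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
    logits = anchor_dot_contrast - logits_max.detach()

    # Tile the label mask to cover all views, then drop self-contrast
    n = bs * n_views
    mask = mask.repeat(n_views, n_views)
    logits_mask = 1.0 - torch.eye(n, device=features.device)
    mask = mask * logits_mask

    exp_logits = torch.exp(logits) * logits_mask
    log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))
    mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)
    return -mean_log_prob_pos.mean()  # average over all n anchors


# Usage: 4 images with labels 3, 0, 2, 3, two views each, 128-dim embeddings
torch.manual_seed(0)
feats = F.normalize(torch.randn(4, 2, 128), dim=-1)
print(supcon_loss(feats, torch.tensor([3, 0, 2, 3])))
```

Taking features as [bs, n_views, dim] means the same function handles more than two views per image without changes.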

That is the supervised contrastive loss. The self-supervised contrastive loss should now be easy to understand, because it is simpler: it is the special case in which each image's only positive is its own augmented view.
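To see the self-supervised case concretely, here is a sketch of its mask under the assumption of a SimCLR-style setup (the framework of reference [3]) with no labels: the label mask becomes the identity, so after removing self-contrast each anchor keeps exactly one positive, the other view of the same image.

```python
import torch

bs, n_views = 4, 2
# No labels: every image is its own class, so the label mask is the identity
mask = torch.eye(bs).repeat(n_views, n_views)
mask = mask * (1.0 - torch.eye(bs * n_views))  # remove self-contrast
print(mask.sum(1))     # exactly one positive per anchor
print(mask.argmax(1))  # tensor([4, 5, 6, 7, 0, 1, 2, 3]): the other view of the same image
```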

According to our results, the larger contrast_count (the number of augmented views per image) is, the better the model. You need to change contrast_count to a value above 2 to see this; we hope you can try it with the help of the instructions above.
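With more views per image, the main changes are the batch size seen by the network (bs × contrast_count images) and the tiling of the mask. This sketch, assuming the same labels [3, 0, 2, 3] as before, only covers the mask part and shows how the number of positives per anchor grows; it is not a full training run.

```python
import torch

labels = torch.tensor([3, 0, 2, 3]).view(-1, 1)  # labels from the example above
positives = {}
for contrast_count in (2, 3):
    n = len(labels) * contrast_count  # rows in the stacked batch
    mask = torch.eq(labels, labels.T).float().repeat(contrast_count, contrast_count)
    mask = mask * (1.0 - torch.eye(n))  # drop self-contrast
    positives[contrast_count] = mask.sum(1)[:len(labels)].tolist()
    print(contrast_count, tuple(mask.shape), positives[contrast_count])
# 2 (8, 8) [3.0, 1.0, 1.0, 3.0]
# 3 (12, 12) [5.0, 2.0, 2.0, 5.0]
```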

References

  • [1] Supervised Contrastive Learning.
  • [2] Florian Schroff, Dmitry Kalenichenko, and James Philbin. FaceNet: A unified embedding for face recognition and clustering. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pages 815–823, 2015.
  • [3] Ting Chen, Simon Kornblith, Mohammad Norouzi, and Geoffrey Hinton. A Simple Framework for Contrastive Learning of Visual Representations.
  • [4] github.com/google-rese…

Original link: towardsdatascience.com/a-detailed-…
