Table of Contents:

Multi-class N-pair Loss

(Sohn, NeurIPS 2016)

ConVIRT

(Zhang et al., 2020)

CLIP

(Radford et al., 2021)

  • Given a batch of N (image, text) pairs, CLIP is trained to predict which of the N × N possible (image, text) pairings across a batch actually occurred.
  • To do this, CLIP learns a multi-modal embedding space by jointly training an image encoder and text encoder to maximize the cosine similarity of the image and text embeddings of the N real pairs in the batch while minimizing the cosine similarity of the embeddings of the \(N^2 - N\) incorrect pairings (the resulting loss is written out below).
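Written out explicitly (the notation here is ours, not the paper's): with L2-normalized image embeddings \(I_i\), text embeddings \(T_j\), and learned temperature \(\tau\), the batch loss is the symmetric cross-entropy

\[
\mathcal{L} = \frac{1}{2N} \sum_{i=1}^{N} \left[ -\log \frac{\exp(\langle I_i, T_i \rangle / \tau)}{\sum_{j=1}^{N} \exp(\langle I_i, T_j \rangle / \tau)} \; - \; \log \frac{\exp(\langle I_i, T_i \rangle / \tau)}{\sum_{j=1}^{N} \exp(\langle I_j, T_i \rangle / \tau)} \right],
\]

where the first term classifies the correct text for each image and the second classifies the correct image for each text.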

Using CLIP for zero-shot transfer

CLIP is pre-trained to predict if an image and a text snippet are paired together in its dataset. To perform zero-shot classification on a downstream dataset, we use the names of all the classes in the dataset as the set of potential text pairings and predict the most probable (image, text) pair according to CLIP.
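A minimal sketch of this zero-shot procedure, assuming the reference openai/CLIP package (clip.load, clip.tokenize, encode_image, encode_text); the class names, prompt template, and image path are placeholders:

import torch
import clip
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

class_names = ["dog", "cat", "car"]  # hypothetical class names for the target dataset
text = clip.tokenize([f"a photo of a {c}" for c in class_names]).to(device)
image = preprocess(Image.open("example.jpg")).unsqueeze(0).to(device)  # hypothetical image

with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)
    # Cosine similarities via L2-normalized embeddings, scaled and softmaxed into class probabilities
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)

print(dict(zip(class_names, probs[0].tolist())))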

CLIP Architecture. Source: Radford et al., 2021.
  • CLIP uses only a linear projection to map from each encoder’s representation to the multi-modal embedding space (a sketch follows below).
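A minimal sketch of such a projection head; the dimensions and names here are illustrative, not taken from the paper:

import torch
import torch.nn.functional as F
from torch import nn

class ProjectionHead(nn.Module):
    """Linear map from an encoder's output into the shared multi-modal embedding space."""

    def __init__(self, encoder_dim: int, embed_dim: int) -> None:
        super().__init__()
        self.proj = nn.Linear(encoder_dim, embed_dim, bias=False)

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        # L2-normalize so that dot products between modalities are cosine similarities.
        return F.normalize(self.proj(features), dim=-1)

# Hypothetical dimensions: a 768-d vision encoder and a 512-d text encoder,
# both projected into the same 512-d embedding space.
image_projection = ProjectionHead(768, 512)
text_projection = ProjectionHead(512, 512)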

Pseudocode for CLIP. Source: Radford et al., 2021.

We can see how this is implemented in PyTorch by examining a high-quality open-source implementation on GitHub, simplified below:

import torch
import torch.nn as nn
import torch.nn.functional as F


class ClipLoss(nn.Module):

    def get_ground_truth(self, device, num_logits) -> torch.Tensor:
        # The i-th image matches the i-th text, so the targets are simply 0..N-1.
        labels = torch.arange(num_logits, device=device, dtype=torch.long)
        return labels

    def get_logits(self, image_features, text_features, logit_scale):
        # Scaled pairwise cosine similarities (features are assumed to be L2-normalized).
        logits_per_image = logit_scale * image_features @ text_features.T
        logits_per_text = logit_scale * text_features @ image_features.T
        return logits_per_image, logits_per_text

    def forward(self, image_features, text_features, logit_scale, output_dict=False):
        logits_per_image, logits_per_text = self.get_logits(image_features, text_features, logit_scale)

        labels = self.get_ground_truth(image_features.device, logits_per_image.shape[0])

        # Symmetric cross-entropy over the image->text and text->image directions.
        total_loss = (
            F.cross_entropy(logits_per_image, labels) +
            F.cross_entropy(logits_per_text, labels)
        ) / 2

        return {"contrastive_loss": total_loss} if output_dict else total_loss

The softmax temperature is a learned parameter, implemented as follows:

class CLIP(nn.Module):
    def __init__(self, ...) -> None:
        ...
        # Learned temperature, stored as a log-scale so it stays positive;
        # log(1 / 0.07) corresponds to the paper's initial temperature of 0.07.
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
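At each forward pass this log-scale is exponentiated before multiplying the similarities; the paper also clips it so that logits are never scaled by more than 100, which it reports is needed for training stability. A sketch of how the scale enters the logits (the function name is ours):

import torch

def scaled_similarities(image_features: torch.Tensor,
                        text_features: torch.Tensor,
                        logit_scale: torch.Tensor) -> torch.Tensor:
    # Exponentiate the learned log-scale and clip it so logits are never
    # scaled by more than 100, as described in the paper.
    scale = logit_scale.exp().clamp(max=100.0)
    return scale * image_features @ text_features.T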

References

  1. Kihyuk Sohn. Improved Deep Metric Learning with Multi-class N-pair Loss Objective. NeurIPS, 2016. [PDF].

  2. Yuhao Zhang, Hang Jiang, Yasuhide Miura, Christopher D. Manning, Curtis P. Langlotz. Contrastive Learning of Medical Visual Representations from Paired Images and Text. arXiv, 2020. [PDF].

  3. Alec Radford, Jong Wook Kim, Chris Hallacy, Aditya Ramesh, Gabriel Goh, Sandhini Agarwal, Girish Sastry, Amanda Askell, Pamela Mishkin, Jack Clark, Gretchen Krueger, Ilya Sutskever. Learning Transferable Visual Models From Natural Language Supervision. ICML, 2021. [PDF].