Contrastive Learning
Table of Contents:
Multi-class N-pair Loss
ConVIRT
CLIP
- Given a batch of N (image, text) pairs, CLIP is trained to predict which of the N × N possible (image, text) pairings across a batch actually occurred.
- To do this, CLIP learns a multi-modal embedding space by jointly training an image encoder and text encoder to maximize the cosine similarity of the image and text embeddings of the N real pairs in the batch while minimizing the cosine similarity of the embeddings of the \(N^2 - N\) incorrect pairings.
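Written out (a sketch in standard notation rather than the paper's pseudocode): with L2-normalized image embeddings \(u_i\) and text embeddings \(v_i\) for the \(i\)-th pair and learned temperature \(\tau\) (the code below uses a logit scale equal to \(1/\tau\)), the batch loss averages an image-to-text and a text-to-image cross-entropy term:

\[
\mathcal{L} = \frac{1}{2N} \sum_{i=1}^{N} \left[ -\log \frac{\exp(u_i^\top v_i / \tau)}{\sum_{j=1}^{N} \exp(u_i^\top v_j / \tau)} \; - \; \log \frac{\exp(v_i^\top u_i / \tau)}{\sum_{j=1}^{N} \exp(v_i^\top u_j / \tau)} \right]
\]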
Using CLIP for zero-shot transfer
CLIP is pre-trained to predict whether an image and a text snippet are paired together in its dataset. To use it for zero-shot classification on a downstream dataset, the names of all of the dataset's classes serve as the set of candidate text pairings, and the most probable (image, text) pair according to CLIP gives the prediction (a code sketch is given below).
- CLIP uses only a linear projection to map from each encoder’s representation to the multi-modal embedding space.
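A minimal zero-shot classification sketch, assuming the reference openai/CLIP package; the model name, image path, class names, and prompt template are illustrative choices, not prescribed by the paper:

import torch
import clip  # https://github.com/openai/CLIP
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

class_names = ["dog", "cat", "car"]  # hypothetical label set for the downstream dataset
text = clip.tokenize([f"a photo of a {c}" for c in class_names]).to(device)
image = preprocess(Image.open("example.jpg")).unsqueeze(0).to(device)  # hypothetical image path

with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)
    # Normalize so dot products are cosine similarities, matching the pre-training objective.
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    # The predicted class is the text pairing with the highest similarity to the image.
    similarity = (image_features @ text_features.T).softmax(dim=-1)

print(class_names[similarity.argmax(dim=-1).item()])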
We can see how this contrastive objective is implemented in PyTorch by examining a high-quality open-source implementation on GitHub, simplified below:
import torch
import torch.nn as nn
import torch.nn.functional as F


class ClipLoss(nn.Module):

    def get_ground_truth(self, device, num_logits) -> torch.Tensor:
        # The i-th image matches the i-th text, so the targets are simply 0..N-1.
        labels = torch.arange(num_logits, device=device, dtype=torch.long)
        return labels

    def get_logits(self, image_features, text_features, logit_scale):
        # Pairwise cosine similarities (features are assumed L2-normalized), scaled by the learned temperature.
        logits_per_image = logit_scale * image_features @ text_features.T
        logits_per_text = logit_scale * text_features @ image_features.T
        return logits_per_image, logits_per_text

    def forward(self, image_features, text_features, logit_scale, output_dict=False):
        logits_per_image, logits_per_text = self.get_logits(image_features, text_features, logit_scale)
        labels = self.get_ground_truth(image_features.device, logits_per_image.shape[0])
        # Symmetric cross-entropy over the image-to-text and text-to-image directions.
        total_loss = (
            F.cross_entropy(logits_per_image, labels) +
            F.cross_entropy(logits_per_text, labels)
        ) / 2
        return {"contrastive_loss": total_loss} if output_dict else total_loss
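As a quick sanity check, the loss can be exercised with random stand-in features (reusing the imports above; the shapes and values here are arbitrary):

batch_size, dim = 8, 512
image_features = F.normalize(torch.randn(batch_size, dim), dim=-1)
text_features = F.normalize(torch.randn(batch_size, dim), dim=-1)
logit_scale = torch.tensor(1 / 0.07)  # the initial value of the learned scale in CLIP

loss = ClipLoss()(image_features, text_features, logit_scale)
print(loss)  # a scalar tensor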
The softmax temperature is a learned parameter, implemented in the CLIP module as follows:
class CLIP(nn.Module):
    def __init__(self, ...) -> None:
        ...
        # Learned temperature, stored as the log of the inverse temperature and initialized to log(1/0.07).
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
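The parameter stores the log of the inverse temperature (np.log(1 / 0.07) ≈ 2.66, i.e. an initial scale of about 14.3). At use time it is exponentiated before scaling the logits, and the CLIP paper notes the scale is clipped so it never exceeds 100. A sketch of that step (the function name and clamp placement are illustrative):

def apply_logit_scale(image_features, text_features, logit_scale_param):
    # Exponentiate the learned log-parameter; cap the scale at 100 as described in the paper.
    scale = logit_scale_param.exp().clamp(max=100.0)
    return scale * image_features @ text_features.T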
References
- Kihyuk Sohn. Improved Deep Metric Learning with Multi-class N-pair Loss Objective. NeurIPS, 2016. [PDF].
- Yuhao Zhang, Hang Jiang, Yasuhide Miura, Christopher D. Manning, Curtis P. Langlotz. Contrastive Learning of Medical Visual Representations from Paired Images and Text. [PDF].
- Alec Radford, Jong Wook Kim, Chris Hallacy, Aditya Ramesh, Gabriel Goh, Sandhini Agarwal, Girish Sastry, Amanda Askell, Pamela Mishkin, Jack Clark, Gretchen Krueger, Ilya Sutskever. Learning Transferable Visual Models From Natural Language Supervision. [PDF].