dlc2action.loss.contrastive
Losses used by contrastive SSL constructors (see `dlc2action.ssl.contrastive`)
#
# Copyright 2020-2022 by A. Mathis Group and contributors. All rights reserved.
#
# This project and all its files are licensed under GNU AGPLv3 or later version. A copy is included in dlc2action/LICENSE.AGPL.
#
"""
Losses used by contrastive SSL constructors (see `dlc2action.ssl.contrastive`)
"""

from torch import nn
import torch
import numpy as np
from itertools import combinations_with_replacement


def _pair_indices(batch_size: int, device=None):
    """
    Enumerate all unordered sample pairs `(i, j)` with `i <= j`

    The pairs come in the same row-major order that
    `itertools.combinations_with_replacement(range(batch_size), 2)` yields, but they are
    built directly as tensors on `device` instead of Python lists.

    Parameters
    ----------
    batch_size : int
        the number of samples in the batch
    device : torch.device, optional
        the device to create the index tensors on

    Returns
    -------
    first, second : torch.Tensor
        long tensors of shape `(#pairs,)` with the first and second sample index of every pair
    positive : torch.Tensor
        a boolean tensor of shape `(#pairs,)` that is `True` where `first == second`
    """

    pairs = torch.triu_indices(batch_size, batch_size, device=device)
    first, second = pairs[0], pairs[1]
    return first, second, first == second


class _NTXent(nn.Module):
    """
    NT-Xent loss for the contrastive SSL module
    """

    def __init__(self, tau: float):
        """
        Parameters
        ----------
        tau : float
            the tau parameter (the cosine similarities are divided by it before exponentiation)
        """

        super().__init__()
        self.tau = tau

    def _exp_similarity(self, tensor1, tensor2):
        """
        Compute exponential similarity: `exp(cosine_similarity / tau)` for every row pair
        """

        return torch.exp(torch.cosine_similarity(tensor1, tensor2, dim=1) / self.tau)

    def _loss(self, similarity, true_indices, denom):
        """
        Compute one of the symmetric components of the loss

        `true_indices` selects the positive pairs in `similarity`; `denom` is the scalar
        normalizer that every positive similarity is divided by.
        """

        return -torch.mean(torch.log(similarity[true_indices] / denom))

    def forward(self, features1, features2):
        """
        Compute the loss

        Parameters
        ----------
        features1, features2 : torch.Tensor
            tensor of shape `(#batch, #features)`

        Returns
        -------
        loss : torch.Tensor
            the loss value (a scalar tensor; zero when the batch has fewer than two samples)
        """

        batch_size = len(features1)
        if batch_size < 2:
            # There are no negative pairs to contrast against. Return a zero that is
            # still a tensor attached to the graph, so that callers can treat the
            # result uniformly (`.backward()`, `.item()`, ...).
            return 0.0 * (features1.sum() + features2.sum())
        first, second, positive = _pair_indices(batch_size, features1.device)
        negative = ~positive
        within1 = self._exp_similarity(features1[first], features1[second])
        within2 = self._exp_similarity(features2[first], features2[second])
        across = self._exp_similarity(features1[first], features2[second])
        # Within one view the (i, i) pairs are self-similarities and are left out
        negatives1 = within1[negative].sum()
        negatives2 = within2[negative].sum()
        across_total = across.sum()
        return self._loss(across, positive, across_total + negatives1) + self._loss(
            across, positive, across_total + negatives2
        )


class _PairwiseDistanceLoss(nn.Module):
    """
    Shared base of the pairwise losses: resolves the `distance` option into a callable

    Both supported metrics are distances, i.e. smaller values mean more similar samples.
    """

    def __init__(self, distance: str = "cosine"):
        """
        Parameters
        ----------
        distance : {'cosine', 'euclidean'}
            the distance metric (cosine distance or squared euclidean distance)

        Raises
        ------
        ValueError
            if `distance` is not one of the supported options
        """

        super().__init__()
        if distance == "euclidean":
            self.distance = self._euclidean_distance
        elif distance == "cosine":
            # The losses below are written for a distance, so cosine *similarity*
            # would reward pushing the positive pairs apart.
            self.distance = self._cosine_distance
        else:
            raise ValueError(
                f'The {distance} is not available, please choose from "euclidean" and "cosine"'
            )

    def _euclidean_distance(self, tensor1, tensor2):
        """
        Compute the squared euclidean distance between matching rows
        """

        return torch.sum((tensor1 - tensor2) ** 2, dim=1)

    def _cosine_similarity(self, tensor1, tensor2):
        """
        Compute cosine similarity between matching rows
        """

        return torch.cosine_similarity(tensor1, tensor2, dim=1)

    def _cosine_distance(self, tensor1, tensor2):
        """
        Compute cosine distance (`1 - cosine similarity`) between matching rows
        """

        return 1 - self._cosine_similarity(tensor1, tensor2)


class _TripletLoss(_PairwiseDistanceLoss):
    """
    Triplet loss for the pairwise SSL module

    A slightly modified version: a Softplus function is applied at the last step instead of ReLU to keep
    the result differentiable
    """

    def __init__(self, margin: float = 0, distance: str = "cosine"):
        """
        Parameters
        ----------
        margin : float, default 0
            the margin parameter
        distance : {'cosine', 'euclidean'}
            the distance metric (cosine distance or squared euclidean distance)
        """

        super().__init__(distance)
        self.margin = margin
        self.nl = nn.Softplus()

    def forward(self, features1, features2):
        """
        Compute the loss

        Row `i` of `features2` is the positive for row `i` of `features1`; the negative is the
        next row of `features2` (wrapping around at the end of the batch).

        Parameters
        ----------
        features1, features2 : torch.Tensor
            tensor of shape `(#batch, #features)`

        Returns
        -------
        loss : torch.Tensor
            the loss value (a scalar tensor)
        """

        negative = torch.cat([features2[1:], features2[:1]])
        positive_distance = self.distance(features1, features2)
        negative_distance = self.distance(features1, negative)
        return torch.mean(self.nl(positive_distance - negative_distance + self.margin))


class _CircleLoss(_PairwiseDistanceLoss):
    """
    Circle loss for the pairwise SSL module
    """

    def __init__(self, gamma: float = 1, margin: float = 0, distance: str = "cosine"):
        """
        Parameters
        ----------
        gamma : float, default 1
            the gamma parameter
        margin : float, default 0
            the margin parameter
        distance : {'cosine', 'euclidean'}
            the distance metric (cosine distance or squared euclidean distance)
        """

        super().__init__(distance)
        self.gamma = gamma
        self.margin = margin

    def forward(self, features1, features2):
        """
        Compute the loss

        Every pair `(i, j)` with `i <= j` is scored with the distance between row `i` of
        `features1` and row `j` of `features2`; the pairs with `i == j` are the positives.

        Parameters
        ----------
        features1, features2 : torch.Tensor
            tensor of shape `(#batch, #features)`

        Returns
        -------
        loss : torch.Tensor
            the loss value (a scalar tensor; zero when there are no negative pairs)
        """

        first, second, positive = _pair_indices(len(features1), features1.device)
        distances = self.distance(features1[first], features2[second])
        # Positive pairs are penalized for being far apart (plus the margin),
        # negative pairs for being close.
        positive_term = torch.exp(self.gamma * (distances[positive] + self.margin)).sum()
        negative_term = torch.exp(-self.gamma * distances[~positive]).sum()
        return torch.log(1 + positive_term * negative_term)