dlc2action.loss.contrastive
Losses used by contrastive SSL constructors (see dlc2action.ssl.contrastive).
#
# Copyright 2020-present by A. Mathis Group and contributors. All rights reserved.
#
# This project and all its files are licensed under GNU AGPLv3 or later version.
# A copy is included in dlc2action/LICENSE.AGPL.
#
"""Losses used by contrastive SSL constructors (see `dlc2action.ssl.contrastive`)."""

from itertools import combinations_with_replacement

import numpy as np
import torch
from torch import nn


class NTXent(nn.Module):
    """NT-Xent loss for the contrastive SSL module.

    Row ``i`` of ``features1`` and row ``i`` of ``features2`` form a positive
    pair; all other pairs in the batch contribute to the denominator.
    """

    def __init__(self, tau: float):
        """Initialize the loss.

        Parameters
        ----------
        tau : float
            the temperature; cosine similarities are divided by it before
            exponentiation

        """
        super().__init__()
        self.tau = tau

    def _exp_similarity(self, tensor1, tensor2):
        """Compute ``exp(cosine_similarity / tau)`` between matching rows."""
        s = torch.cosine_similarity(tensor1, tensor2, dim=1) / self.tau
        return torch.exp(s)

    def _loss(self, similarity, true_indices, denom):
        """Compute one of the symmetric components of the loss.

        Returns ``-mean(log(similarity[true_indices] / denom))``.
        """
        log_ratio = torch.log(similarity[true_indices] / denom)
        return -torch.mean(log_ratio)

    def forward(self, features1, features2):
        """Compute the loss.

        Parameters
        ----------
        features1, features2 : torch.Tensor
            tensors of shape `(#batch, #features)`

        Returns
        -------
        loss : torch.Tensor
            the scalar loss value; a zero tensor (same dtype and device as
            ``features1``) when the batch has fewer than two samples

        """
        n = len(features1)
        if n < 2:
            # Without a second sample there are no negatives.  Return a tensor
            # rather than the Python int 0 so callers can still use tensor
            # methods such as ``.item()`` / ``.detach()`` on the result.
            return features1.new_zeros(())
        # All pairs (i, j) with i <= j; the first pair listed for each i is (i, i).
        indices = list(combinations_with_replacement(range(n), 2))
        indices1, indices2 = map(list, zip(*indices))
        # positions of the diagonal (i, i) pairs, i.e. the positives
        true = np.unique(indices1, return_index=True)[1]
        similarity1 = self._exp_similarity(features1[indices1], features1[indices2])
        similarity2 = self._exp_similarity(features2[indices1], features2[indices2])
        similarity12 = self._exp_similarity(features1[indices1], features2[indices2])
        # within-view sums exclude the trivial self-similarities on the diagonal
        sum1 = similarity1.sum() - similarity1[true].sum()
        sum2 = similarity2.sum() - similarity2[true].sum()
        sum12 = similarity12.sum()
        # NOTE(review): the denominator is a single batch-wide sum over the
        # upper triangle of pairs (i <= j), not the per-anchor sum of the
        # standard NT-Xent formulation; kept as-is -- confirm this is intended.
        return self._loss(similarity12, true, sum12 + sum1) + self._loss(
            similarity12, true, sum12 + sum2
        )


class TripletLoss(nn.Module):
    """Triplet loss for the pairwise SSL module.

    A slightly modified version: the final hinge uses Softplus instead of
    ReLU so that the loss is smooth.  Row ``i`` of ``features1`` and
    ``features2`` is a positive pair; the negative for row ``i`` is row
    ``i + 1`` of ``features2`` (cyclically).
    """

    def __init__(self, margin: float = 0, distance: str = "cosine"):
        """Initialize the loss.

        Parameters
        ----------
        margin : float, default 0
            the margin parameter
        distance : {'cosine', 'euclidean'}
            the pairwise measure: cosine (compared as ``1 - cosine similarity``)
            or squared euclidean distance

        Raises
        ------
        ValueError
            if ``distance`` is not one of the supported names

        """
        super().__init__()
        self.margin = margin
        self.nl = nn.Softplus()
        if distance == "euclidean":
            self.distance = self._euclidean_distance
        elif distance == "cosine":
            self.distance = self._cosine_similarity
        else:
            raise ValueError(
                f'The {distance} is not available, please choose from "euclidean" and "cosine"'
            )

    def _euclidean_distance(self, tensor1, tensor2):
        """Compute the squared euclidean distance between matching rows."""
        return torch.sum((tensor1 - tensor2) ** 2, dim=1)

    def _cosine_similarity(self, tensor1, tensor2):
        """Compute the cosine similarity between matching rows."""
        return torch.cosine_similarity(tensor1, tensor2, dim=1)

    def forward(self, features1, features2):
        """Compute the loss.

        Parameters
        ----------
        features1, features2 : torch.Tensor
            tensors of shape `(#batch, #features)`

        Returns
        -------
        loss : torch.Tensor
            the scalar loss value

        """
        negative = torch.cat([features2[1:], features2[:1]])
        positive_distance = self.distance(features1, features2)
        negative_distance = self.distance(features1, negative)
        if self.distance == self._cosine_similarity:
            # Cosine *similarity* grows as embeddings get closer.  Convert it
            # to the cosine distance so that the hinge below pulls positives
            # together and pushes negatives apart (as with euclidean).
            positive_distance = 1 - positive_distance
            negative_distance = 1 - negative_distance
        return torch.mean(self.nl(positive_distance - negative_distance + self.margin))


class CircleLoss(nn.Module):
    """Circle loss for the pairwise SSL module.

    Row ``i`` of ``features1`` and row ``i`` of ``features2`` form a positive
    pair; pairs ``(i, j)`` with ``i < j`` are treated as negatives.
    """

    def __init__(self, gamma: float = 1, margin: float = 0, distance: str = "cosine"):
        """Initialize the loss.

        Parameters
        ----------
        gamma : float, default 1
            the gamma (scale) parameter
        margin : float, default 0
            the margin parameter
        distance : {'cosine', 'euclidean'}
            the pairwise measure: cosine (compared as ``1 - cosine similarity``)
            or squared euclidean distance

        Raises
        ------
        ValueError
            if ``distance`` is not one of the supported names

        """
        super().__init__()
        self.gamma = gamma
        self.margin = margin
        if distance == "euclidean":
            self.distance = self._euclidean_distance
        elif distance == "cosine":
            self.distance = self._cosine_similarity
        else:
            raise ValueError(
                f'The {distance} is not available, please choose from "euclidean" and "cosine"'
            )

    def _euclidean_distance(self, tensor1, tensor2):
        """Compute the squared euclidean distance between matching rows."""
        return torch.sum((tensor1 - tensor2) ** 2, dim=1)

    def _cosine_similarity(self, tensor1, tensor2):
        """Compute the cosine similarity between matching rows."""
        return torch.cosine_similarity(tensor1, tensor2, dim=1)

    def forward(self, features1, features2):
        """Compute the loss.

        ``log(1 + sum_pos exp(gamma * (d_p + margin)) * sum_neg exp(-gamma * d_n))``

        Parameters
        ----------
        features1, features2 : torch.Tensor
            tensors of shape `(#batch, #features)`

        Returns
        -------
        loss : torch.Tensor
            the scalar loss value

        """
        indices = list(combinations_with_replacement(range(len(features1)), 2))
        indices1, indices2 = map(list, zip(*indices))
        # positives are the diagonal (i, i) pairs; build the mask on the same
        # device as the features instead of on the CPU
        mask = torch.tensor(
            [i == j for i, j in indices], dtype=torch.bool, device=features1.device
        )
        distances = self.distance(features1[indices1], features2[indices2])
        if self.distance == self._cosine_similarity:
            # turn similarity (higher = closer) into a distance (lower = closer)
            distances = 1 - distances
        # out-of-place selection: positives get +margin, negatives are negated
        distances = torch.where(mask, distances + self.margin, -distances)
        distances = torch.exp(self.gamma * distances)
        return torch.log(1 + torch.sum(distances[mask]) * torch.sum(distances[~mask]))
class NTXent(nn.Module):
    """NT-Xent loss for the contrastive SSL module.

    Row ``i`` of ``features1`` and row ``i`` of ``features2`` form a positive
    pair; all other pairs in the batch contribute to the denominator.
    """

    def __init__(self, tau: float):
        """Initialize the loss.

        Parameters
        ----------
        tau : float
            the temperature; cosine similarities are divided by it before
            exponentiation

        """
        super().__init__()
        self.tau = tau

    def _exp_similarity(self, tensor1, tensor2):
        """Compute ``exp(cosine_similarity / tau)`` between matching rows."""
        s = torch.cosine_similarity(tensor1, tensor2, dim=1) / self.tau
        return torch.exp(s)

    def _loss(self, similarity, true_indices, denom):
        """Compute one of the symmetric components of the loss.

        Returns ``-mean(log(similarity[true_indices] / denom))``.
        """
        log_ratio = torch.log(similarity[true_indices] / denom)
        return -torch.mean(log_ratio)

    def forward(self, features1, features2):
        """Compute the loss.

        Parameters
        ----------
        features1, features2 : torch.Tensor
            tensors of shape `(#batch, #features)`

        Returns
        -------
        loss : torch.Tensor
            the scalar loss value; a zero tensor (same dtype and device as
            ``features1``) when the batch has fewer than two samples

        """
        n = len(features1)
        if n < 2:
            # No negatives exist; return a tensor rather than the Python int 0
            # so callers can still use tensor methods on the result.
            return features1.new_zeros(())
        # All pairs (i, j) with i <= j; the first pair listed for each i is (i, i).
        indices = list(combinations_with_replacement(range(n), 2))
        indices1, indices2 = map(list, zip(*indices))
        # positions of the diagonal (i, i) pairs, i.e. the positives
        true = np.unique(indices1, return_index=True)[1]
        similarity1 = self._exp_similarity(features1[indices1], features1[indices2])
        similarity2 = self._exp_similarity(features2[indices1], features2[indices2])
        similarity12 = self._exp_similarity(features1[indices1], features2[indices2])
        # within-view sums exclude the trivial self-similarities on the diagonal
        sum1 = similarity1.sum() - similarity1[true].sum()
        sum2 = similarity2.sum() - similarity2[true].sum()
        sum12 = similarity12.sum()
        # NOTE(review): the denominator is a single batch-wide sum over the
        # upper triangle of pairs (i <= j), not the per-anchor sum of the
        # standard NT-Xent formulation; kept as-is -- confirm this is intended.
        return self._loss(similarity12, true, sum12 + sum1) + self._loss(
            similarity12, true, sum12 + sum2
        )
NT-Xent loss for the contrastive SSL module.
def __init__(self, tau: float):
    """Store the temperature of the NT-Xent loss.

    Parameters
    ----------
    tau : float
        the temperature; cosine similarities are divided by it before
        exponentiation

    """
    super().__init__()
    self.tau = tau
Initialize the loss.
Parameters
tau : float the temperature; cosine similarities are divided by it before exponentiation
def forward(self, features1, features2):
    """Compute the loss.

    Parameters
    ----------
    features1, features2 : torch.Tensor
        tensors of shape `(#batch, #features)`

    Returns
    -------
    loss : torch.Tensor
        the scalar loss value; a zero tensor (same dtype and device as
        ``features1``) when the batch has fewer than two samples

    """
    n = len(features1)
    if n < 2:
        # No negatives exist; return a tensor rather than the Python int 0
        # so callers can still use tensor methods on the result.
        return features1.new_zeros(())
    # All pairs (i, j) with i <= j; the first pair listed for each i is (i, i).
    indices = list(combinations_with_replacement(range(n), 2))
    indices1, indices2 = map(list, zip(*indices))
    # positions of the diagonal (i, i) pairs, i.e. the positives
    true = np.unique(indices1, return_index=True)[1]
    similarity1 = self._exp_similarity(features1[indices1], features1[indices2])
    similarity2 = self._exp_similarity(features2[indices1], features2[indices2])
    similarity12 = self._exp_similarity(features1[indices1], features2[indices2])
    # within-view sums exclude the trivial self-similarities on the diagonal
    sum1 = similarity1.sum() - similarity1[true].sum()
    sum2 = similarity2.sum() - similarity2[true].sum()
    sum12 = similarity12.sum()
    # NOTE(review): one batch-wide denominator over pairs i <= j rather than
    # per-anchor denominators as in standard NT-Xent -- confirm intended.
    return self._loss(similarity12, true, sum12 + sum1) + self._loss(
        similarity12, true, sum12 + sum2
    )
Compute the loss.
Parameters
features1, features2 : torch.Tensor
tensor of shape (#batch, #features)
Returns
loss : torch.Tensor the scalar loss value (zero when the batch has fewer than two samples)
class TripletLoss(nn.Module):
    """Triplet loss for the pairwise SSL module.

    A slightly modified version: the final hinge uses Softplus instead of
    ReLU so that the loss is smooth.  Row ``i`` of ``features1`` and
    ``features2`` is a positive pair; the negative for row ``i`` is row
    ``i + 1`` of ``features2`` (cyclically).
    """

    def __init__(self, margin: float = 0, distance: str = "cosine"):
        """Initialize the loss.

        Parameters
        ----------
        margin : float, default 0
            the margin parameter
        distance : {'cosine', 'euclidean'}
            the pairwise measure: cosine (compared as ``1 - cosine similarity``)
            or squared euclidean distance

        Raises
        ------
        ValueError
            if ``distance`` is not one of the supported names

        """
        super().__init__()
        self.margin = margin
        self.nl = nn.Softplus()
        if distance == "euclidean":
            self.distance = self._euclidean_distance
        elif distance == "cosine":
            self.distance = self._cosine_similarity
        else:
            raise ValueError(
                f'The {distance} is not available, please choose from "euclidean" and "cosine"'
            )

    def _euclidean_distance(self, tensor1, tensor2):
        """Compute the squared euclidean distance between matching rows."""
        return torch.sum((tensor1 - tensor2) ** 2, dim=1)

    def _cosine_similarity(self, tensor1, tensor2):
        """Compute the cosine similarity between matching rows."""
        return torch.cosine_similarity(tensor1, tensor2, dim=1)

    def forward(self, features1, features2):
        """Compute the loss.

        Parameters
        ----------
        features1, features2 : torch.Tensor
            tensors of shape `(#batch, #features)`

        Returns
        -------
        loss : torch.Tensor
            the scalar loss value

        """
        negative = torch.cat([features2[1:], features2[:1]])
        positive_distance = self.distance(features1, features2)
        negative_distance = self.distance(features1, negative)
        if self.distance == self._cosine_similarity:
            # Cosine *similarity* grows as embeddings get closer.  Convert it
            # to the cosine distance so that the hinge below pulls positives
            # together and pushes negatives apart (as with euclidean).
            positive_distance = 1 - positive_distance
            negative_distance = 1 - negative_distance
        return torch.mean(self.nl(positive_distance - negative_distance + self.margin))
Triplet loss for the pairwise SSL module.
This is a slightly modified version: the final step applies a Softplus function instead of ReLU, so that the loss stays smooth (differentiable everywhere).
def __init__(self, margin: float = 0, distance: str = "cosine"):
    """Set up the triplet loss.

    Parameters
    ----------
    margin : float, default 0
        the margin added inside the Softplus hinge
    distance : {'cosine', 'euclidean'}
        the pairwise measure: cosine similarity or squared euclidean distance

    Raises
    ------
    ValueError
        if ``distance`` is not one of the two supported names

    """
    super().__init__()
    self.margin = margin
    self.nl = nn.Softplus()
    candidates = (
        ("euclidean", self._euclidean_distance),
        ("cosine", self._cosine_similarity),
    )
    for name, measure in candidates:
        if distance == name:
            self.distance = measure
            break
    else:
        raise ValueError(
            f'The {distance} is not available, please choose from "euclidean" and "cosine"'
        )
Initialize the loss.
Parameters
margin : float, default 0 the margin parameter distance : {'cosine', 'euclidean'} the pairwise measure (cosine similarity or squared euclidean distance)
def forward(self, features1, features2):
    """Compute the loss.

    Row ``i`` of ``features1`` and ``features2`` is a positive pair; the
    negative for row ``i`` is row ``i + 1`` of ``features2`` (cyclically).

    Parameters
    ----------
    features1, features2 : torch.Tensor
        tensors of shape `(#batch, #features)`

    Returns
    -------
    loss : torch.Tensor
        the scalar loss value

    """
    negative = torch.cat([features2[1:], features2[:1]])
    positive_distance = self.distance(features1, features2)
    negative_distance = self.distance(features1, negative)
    if self.distance == self._cosine_similarity:
        # Cosine *similarity* grows as embeddings get closer.  Convert it to
        # the cosine distance so that the hinge below pulls positives
        # together and pushes negatives apart (as with euclidean).
        positive_distance = 1 - positive_distance
        negative_distance = 1 - negative_distance
    return torch.mean(self.nl(positive_distance - negative_distance + self.margin))
Compute the loss.
Parameters
features1, features2 : torch.Tensor
tensor of shape (#batch, #features)
Returns
loss : torch.Tensor the scalar loss value
class CircleLoss(nn.Module):
    """Circle loss for the pairwise SSL module.

    Row ``i`` of ``features1`` and row ``i`` of ``features2`` form a positive
    pair; pairs ``(i, j)`` with ``i < j`` are treated as negatives.
    """

    def __init__(self, gamma: float = 1, margin: float = 0, distance: str = "cosine"):
        """Initialize the loss.

        Parameters
        ----------
        gamma : float, default 1
            the gamma (scale) parameter
        margin : float, default 0
            the margin parameter
        distance : {'cosine', 'euclidean'}
            the pairwise measure: cosine (compared as ``1 - cosine similarity``)
            or squared euclidean distance

        Raises
        ------
        ValueError
            if ``distance`` is not one of the supported names

        """
        super().__init__()
        self.gamma = gamma
        self.margin = margin
        if distance == "euclidean":
            self.distance = self._euclidean_distance
        elif distance == "cosine":
            self.distance = self._cosine_similarity
        else:
            raise ValueError(
                f'The {distance} is not available, please choose from "euclidean" and "cosine"'
            )

    def _euclidean_distance(self, tensor1, tensor2):
        """Compute the squared euclidean distance between matching rows."""
        return torch.sum((tensor1 - tensor2) ** 2, dim=1)

    def _cosine_similarity(self, tensor1, tensor2):
        """Compute the cosine similarity between matching rows."""
        return torch.cosine_similarity(tensor1, tensor2, dim=1)

    def forward(self, features1, features2):
        """Compute the loss.

        ``log(1 + sum_pos exp(gamma * (d_p + margin)) * sum_neg exp(-gamma * d_n))``

        Parameters
        ----------
        features1, features2 : torch.Tensor
            tensors of shape `(#batch, #features)`

        Returns
        -------
        loss : torch.Tensor
            the scalar loss value

        """
        indices = list(combinations_with_replacement(range(len(features1)), 2))
        indices1, indices2 = map(list, zip(*indices))
        # positives are the diagonal (i, i) pairs; build the mask on the same
        # device as the features instead of on the CPU
        mask = torch.tensor(
            [i == j for i, j in indices], dtype=torch.bool, device=features1.device
        )
        distances = self.distance(features1[indices1], features2[indices2])
        if self.distance == self._cosine_similarity:
            # turn similarity (higher = closer) into a distance (lower = closer)
            distances = 1 - distances
        # out-of-place selection: positives get +margin, negatives are negated
        distances = torch.where(mask, distances + self.margin, -distances)
        distances = torch.exp(self.gamma * distances)
        return torch.log(1 + torch.sum(distances[mask]) * torch.sum(distances[~mask]))
Circle loss for the pairwise SSL module.
def __init__(self, gamma: float = 1, margin: float = 0, distance: str = "cosine"):
    """Set up the circle loss.

    Parameters
    ----------
    gamma : float, default 1
        the scale applied before exponentiation
    margin : float, default 0
        the margin added to the positive-pair term
    distance : {'cosine', 'euclidean'}
        the pairwise measure: cosine similarity or squared euclidean distance

    Raises
    ------
    ValueError
        if ``distance`` is not one of the two supported names

    """
    super().__init__()
    self.gamma = gamma
    self.margin = margin
    candidates = (
        ("euclidean", self._euclidean_distance),
        ("cosine", self._cosine_similarity),
    )
    for name, measure in candidates:
        if distance == name:
            self.distance = measure
            break
    else:
        raise ValueError(
            f'The {distance} is not available, please choose from "euclidean" and "cosine"'
        )
Initialize the loss.
Parameters
gamma : float, default 1 the gamma (scale) parameter margin : float, default 0 the margin parameter distance : {'cosine', 'euclidean'} the pairwise measure (cosine similarity or squared euclidean distance)
def forward(self, features1, features2):
    """Compute the loss.

    ``log(1 + sum_pos exp(gamma * (d_p + margin)) * sum_neg exp(-gamma * d_n))``,
    where positives are the pairs ``(i, i)`` and negatives the pairs
    ``(i, j)`` with ``i < j``.

    Parameters
    ----------
    features1, features2 : torch.Tensor
        tensors of shape `(#batch, #features)`

    Returns
    -------
    loss : torch.Tensor
        the scalar loss value

    """
    indices = list(combinations_with_replacement(range(len(features1)), 2))
    indices1, indices2 = map(list, zip(*indices))
    # positives are the diagonal (i, i) pairs; build the mask on the same
    # device as the features instead of on the CPU
    mask = torch.tensor(
        [i == j for i, j in indices], dtype=torch.bool, device=features1.device
    )
    distances = self.distance(features1[indices1], features2[indices2])
    if self.distance == self._cosine_similarity:
        # turn similarity (higher = closer) into a distance (lower = closer)
        distances = 1 - distances
    # out-of-place selection: positives get +margin, negatives are negated
    distances = torch.where(mask, distances + self.margin, -distances)
    distances = torch.exp(self.gamma * distances)
    return torch.log(1 + torch.sum(distances[mask]) * torch.sum(distances[~mask]))
Compute the loss.
Parameters
features1, features2 : torch.Tensor
tensor of shape (#batch, #features)
Returns
loss : torch.Tensor the scalar loss value