dlc2action.loss.contrastive

Losses used by contrastive SSL constructors (see dlc2action.ssl.contrastive)

  1#
  2# Copyright 2020-2022 by A. Mathis Group and contributors. All rights reserved.
  3#
  4# This project and all its files are licensed under GNU AGPLv3 or later version. A copy is included in dlc2action/LICENSE.AGPL.
  5#
  6"""
  7Losses used by contrastive SSL constructors (see `dlc2action.ssl.contrastive`)
  8"""
  9
 10from torch import nn
 11import torch
 12import numpy as np
 13from itertools import combinations_with_replacement
 14
 15
 16class _NTXent(nn.Module):
 17    """
 18    NT-Xent loss for the contrastive SSL module
 19    """
 20
 21    def __init__(self, tau: float):
 22        """
 23        Parameters
 24        ----------
 25        tau : float
 26            the tau parameter
 27        """
 28
 29        super().__init__()
 30        self.tau = tau
 31
 32    def _exp_similarity(self, tensor1, tensor2):
 33        """
 34        Compute exponential similarity
 35        """
 36
 37        s = torch.cosine_similarity(tensor1, tensor2, dim=1) / self.tau
 38        s = torch.exp(s)
 39        return s
 40
 41    def _loss(self, similarity, true_indices, denom):
 42        """
 43        Compute one of the symmetric components of the loss
 44        """
 45
 46        l = torch.log(similarity[true_indices] / denom)
 47        l = -torch.mean(l)
 48        return l
 49
 50    def forward(self, features1, features2):
 51        """
 52        Compute the loss
 53        Parameters
 54        ----------
 55        features1, features2 : torch.Tensor
 56            tensor of shape `(#batch, #features)`
 57        Returns
 58        -------
 59        loss : float
 60            the loss value
 61        """
 62
 63        indices = list(combinations_with_replacement(list(range(len(features1))), 2))
 64        if len(indices) < 3:
 65            return 0
 66        indices1, indices2 = map(list, zip(*indices))
 67        true = np.unique(indices1, return_index=True)[1]
 68        similarity1 = self._exp_similarity(features1[indices1], features1[indices2])
 69        similarity2 = self._exp_similarity(features2[indices1], features2[indices2])
 70        similarity12 = self._exp_similarity(features1[indices1], features2[indices2])
 71        sum1 = similarity1.sum() - similarity1[true].sum()
 72        sum2 = similarity2.sum() - similarity2[true].sum()
 73        sum12 = similarity12.sum()
 74        loss = self._loss(similarity12, true, sum12 + sum1) + self._loss(
 75            similarity12, true, sum12 + sum2
 76        )
 77        return loss
 78
 79
 80class _TripletLoss(nn.Module):
 81    """
 82    Triplet loss for the pairwise SSL module
 83    A slightly modified version: a Softplus function is applied at the last step instead of ReLU to keep
 84    the result differentiable
 85    """
 86
 87    def __init__(self, margin: float = 0, distance: str = "cosine"):
 88        """
 89        Parameters
 90        ----------
 91        margin : float, default 0
 92            the margin parameter
 93        distance : {'cosine', 'euclidean'}
 94            the distance metric (cosine similarity ot euclidean distance)
 95        """
 96
 97        super().__init__()
 98        self.margin = margin
 99        self.nl = nn.Softplus()
100        if distance == "euclidean":
101            self.distance = self._euclidean_distance
102        elif distance == "cosine":
103            self.distance = self._cosine_similarity
104        else:
105            raise ValueError(
106                f'The {distance} is not available, please choose from "euclidean" and "cosine"'
107            )
108
109    def _euclidean_distance(self, tensor1, tensor2):
110        """
111        Compute euclidean distance
112        """
113
114        return torch.sum((tensor1 - tensor2) ** 2, dim=1)
115
116    def _cosine_similarity(self, tensor1, tensor2):
117        """
118        Compute cosine similarity
119        """
120
121        return torch.cosine_similarity(tensor1, tensor2, dim=1)
122
123    def forward(self, features1, features2):
124        """
125        Compute the loss
126        Parameters
127        ----------
128        features1, features2 : torch.Tensor
129            tensor of shape `(#batch, #features)`
130        Returns
131        -------
132        loss : float
133            the loss value
134        """
135
136        negative = torch.cat([features2[1:], features2[:1]])
137        positive_distance = self.distance(features1, features2)
138        negative_distance = self.distance(features1, negative)
139        loss = torch.mean(self.nl(positive_distance - negative_distance + self.margin))
140        return loss
141
142
class _CircleLoss(nn.Module):
    """
    Circle loss for the pairwise SSL module

    All index pairs ``(i, j)`` with ``i <= j`` are scored on ``(features1[i], features2[j])``.
    Diagonal pairs are positives; the rest are negatives. The loss is
    ``log(1 + sum(exp(gamma * (d_pos + margin))) * sum(exp(-gamma * d_neg)))``.
    It is evaluated as ``softplus(logsumexp(.) + logsumexp(.))`` so that large values do not
    overflow.
    """

    def __init__(self, gamma: float = 1, margin: float = 0, distance: str = "cosine"):
        """
        Parameters
        ----------
        gamma : float, default 1
            the gamma parameter
        margin : float, default 0
            the margin parameter
        distance : {'cosine', 'euclidean'}
            the distance metric (cosine distance, i.e. ``1 - cosine similarity``, or squared euclidean distance)

        Raises
        ------
        ValueError
            if `distance` is not one of the supported options
        """

        super().__init__()
        self.gamma = gamma
        self.margin = margin
        if distance == "euclidean":
            self.distance = self._euclidean_distance
        elif distance == "cosine":
            # the loss treats its inputs as distances (smaller = closer); feeding the raw
            # similarity here would push positive pairs apart instead of together
            self.distance = self._cosine_distance
        else:
            raise ValueError(
                f'The {distance} is not available, please choose from "euclidean" and "cosine"'
            )

    def _euclidean_distance(self, tensor1, tensor2):
        """
        Compute squared euclidean distance along dim 1
        """

        return torch.sum((tensor1 - tensor2) ** 2, dim=1)

    def _cosine_similarity(self, tensor1, tensor2):
        """
        Compute cosine similarity along dim 1
        """

        return torch.cosine_similarity(tensor1, tensor2, dim=1)

    def _cosine_distance(self, tensor1, tensor2):
        """
        Compute cosine distance (``1 - cosine similarity``) along dim 1
        """

        return 1 - self._cosine_similarity(tensor1, tensor2)

    def forward(self, features1, features2):
        """
        Compute the loss

        Parameters
        ----------
        features1, features2 : torch.Tensor
            tensor of shape `(#batch, #features)`

        Returns
        -------
        loss : torch.Tensor
            the loss value (a scalar tensor; 0 when the batch has a single sample)
        """

        n_samples = len(features1)
        # all (i, j) with i <= j; the mask lives on the features' device
        rows, cols = torch.triu_indices(n_samples, n_samples, device=features1.device)
        positive = rows == cols
        distances = self.distance(features1[rows], features2[cols])
        # positives are penalised for being far (+margin), negatives for being close
        # (sign flipped); torch.where avoids in-place writes on an autograd tensor
        logits = self.gamma * torch.where(positive, distances + self.margin, -distances)
        log_pos = torch.logsumexp(logits[positive], dim=0)
        negative_logits = logits[~positive]
        if negative_logits.numel() == 0:
            # no negatives: log(1 + A * 0) == 0; multiplying keeps the autograd graph connected
            return log_pos * 0
        log_neg = torch.logsumexp(negative_logits, dim=0)
        # log(1 + exp(log_pos + log_neg)) == log(1 + sum(exp(pos)) * sum(exp(neg)))
        return nn.functional.softplus(log_pos + log_neg)