dlc2action.loss.contrastive

Losses used by contrastive SSL constructors (see dlc2action.ssl.contrastive).

  1#
  2# Copyright 2020-present by A. Mathis Group and contributors. All rights reserved.
  3#
  4# This project and all its files are licensed under GNU AGPLv3 or later version. 
  5# A copy is included in dlc2action/LICENSE.AGPL.
  6#
  7"""Losses used by contrastive SSL constructors (see `dlc2action.ssl.contrastive`)."""
  8
  9from itertools import combinations_with_replacement
 10
 11import numpy as np
 12import torch
 13from torch import nn
 14
 15
 16class NTXent(nn.Module):
 17    """NT-Xent loss for the contrastive SSL module."""
 18
 19    def __init__(self, tau: float):
 20        """Initialize the loss.
 21
 22        Parameters
 23        ----------
 24        tau : float
 25            the tau parameter
 26
 27        """
 28        super().__init__()
 29        self.tau = tau
 30
 31    def _exp_similarity(self, tensor1, tensor2):
 32        """Compute exponential similarity."""
 33        s = torch.cosine_similarity(tensor1, tensor2, dim=1) / self.tau
 34        s = torch.exp(s)
 35        return s
 36
 37    def _loss(self, similarity, true_indices, denom):
 38        """Compute one of the symmetric components of the loss."""
 39        l = torch.log(similarity[true_indices] / denom)
 40        l = -torch.mean(l)
 41        return l
 42
 43    def forward(self, features1, features2):
 44        """Compute the loss.
 45
 46        Parameters
 47        ----------
 48        features1, features2 : torch.Tensor
 49            tensor of shape `(#batch, #features)`
 50        Returns
 51        -------
 52        loss : float
 53            the loss value
 54
 55        """
 56        indices = list(combinations_with_replacement(list(range(len(features1))), 2))
 57        if len(indices) < 3:
 58            return 0
 59        indices1, indices2 = map(list, zip(*indices))
 60        true = np.unique(indices1, return_index=True)[1]
 61        similarity1 = self._exp_similarity(features1[indices1], features1[indices2])
 62        similarity2 = self._exp_similarity(features2[indices1], features2[indices2])
 63        similarity12 = self._exp_similarity(features1[indices1], features2[indices2])
 64        sum1 = similarity1.sum() - similarity1[true].sum()
 65        sum2 = similarity2.sum() - similarity2[true].sum()
 66        sum12 = similarity12.sum()
 67        loss = self._loss(similarity12, true, sum12 + sum1) + self._loss(
 68            similarity12, true, sum12 + sum2
 69        )
 70        return loss
 71
 72
 73class TripletLoss(nn.Module):
 74    """Triplet loss for the pairwise SSL module.
 75
 76    A slightly modified version: a Softplus function is applied at the last step instead of ReLU to keep
 77    the result differentiable
 78    """
 79
 80    def __init__(self, margin: float = 0, distance: str = "cosine"):
 81        """Initialize the loss.
 82
 83        Parameters
 84        ----------
 85        margin : float, default 0
 86            the margin parameter
 87        distance : {'cosine', 'euclidean'}
 88            the distance metric (cosine similarity or euclidean distance)
 89
 90        """
 91        super().__init__()
 92        self.margin = margin
 93        self.nl = nn.Softplus()
 94        if distance == "euclidean":
 95            self.distance = self._euclidean_distance
 96        elif distance == "cosine":
 97            self.distance = self._cosine_similarity
 98        else:
 99            raise ValueError(
100                f'The {distance} is not available, please choose from "euclidean" and "cosine"'
101            )
102
103    def _euclidean_distance(self, tensor1, tensor2):
104        """Compute euclidean distance."""
105        return torch.sum((tensor1 - tensor2) ** 2, dim=1)
106
107    def _cosine_similarity(self, tensor1, tensor2):
108        """Compute cosine similarity."""
109        return torch.cosine_similarity(tensor1, tensor2, dim=1)
110
111    def forward(self, features1, features2):
112        """Compute the loss.
113
114        Parameters
115        ----------
116        features1, features2 : torch.Tensor
117            tensor of shape `(#batch, #features)`
118        Returns
119        -------
120        loss : float
121            the loss value
122
123        """
124        negative = torch.cat([features2[1:], features2[:1]])
125        positive_distance = self.distance(features1, features2)
126        negative_distance = self.distance(features1, negative)
127        loss = torch.mean(self.nl(positive_distance - negative_distance + self.margin))
128        return loss
129
130
class CircleLoss(nn.Module):
    """Circle loss for the pairwise SSL module.

    Every index pair ``(i, j)`` with ``i <= j`` is scored with ``d(features1[i], features2[j])``;
    the diagonal pairs are positives and the remaining pairs are negatives. With
    ``distance="cosine"`` the similarity is turned into the distance ``1 - cosine_similarity``
    so that, as with the euclidean option, minimizing the loss pulls positives together and
    pushes negatives apart. The loss is evaluated in log space to avoid overflowing ``exp``.
    """

    def __init__(self, gamma: float = 1, margin: float = 0, distance: str = "cosine"):
        """Initialize the loss.

        Parameters
        ----------
        gamma : float, default 1
            the gamma (scale) parameter multiplying the exponents
        margin : float, default 0
            the margin added to the positive-pair distances
        distance : {'cosine', 'euclidean'}
            the distance metric (cosine similarity or squared euclidean distance)

        Raises
        ------
        ValueError
            if ``distance`` is not one of the supported options

        """
        super().__init__()
        self.gamma = gamma
        self.margin = margin
        if distance == "euclidean":
            self.distance = self._euclidean_distance
        elif distance == "cosine":
            self.distance = self._cosine_similarity
        else:
            raise ValueError(
                f'The {distance} is not available, please choose from "euclidean" and "cosine"'
            )

    def _euclidean_distance(self, tensor1, tensor2):
        """Compute the squared euclidean distance between matching rows."""
        return torch.sum((tensor1 - tensor2) ** 2, dim=1)

    def _cosine_similarity(self, tensor1, tensor2):
        """Compute the cosine similarity between matching rows."""
        return torch.cosine_similarity(tensor1, tensor2, dim=1)

    def forward(self, features1, features2):
        """Compute the loss.

        Parameters
        ----------
        features1, features2 : torch.Tensor
            tensors of shape `(#batch, #features)`; row ``i`` of both tensors forms a positive pair

        Returns
        -------
        loss : torch.Tensor
            the scalar value
            ``log(1 + sum_pos exp(gamma * (d + margin)) * sum_neg exp(-gamma * d))``;
            zero when the batch has fewer than two elements (no negative pairs)

        """
        n_samples = len(features1)
        # all index pairs (i, j) with i <= j -- the same set of pairs as
        # ``combinations_with_replacement(range(n_samples), 2)``
        rows, cols = torch.triu_indices(n_samples, n_samples, device=features1.device)
        positive = rows == cols
        distances = self.distance(features1[rows], features2[cols])
        if self.distance == self._cosine_similarity:
            # a similarity grows as vectors get closer; convert it to a distance
            distances = 1 - distances
        if n_samples < 2:
            # without negatives their sum is 0 and the loss is log(1) = 0; multiply
            # instead of returning a constant so the result stays in the graph
            return distances.sum() * 0.0
        log_positive = torch.logsumexp(self.gamma * (distances[positive] + self.margin), dim=0)
        log_negative = torch.logsumexp(-self.gamma * distances[~positive], dim=0)
        # log(1 + sum_pos * sum_neg) == softplus(log(sum_pos) + log(sum_neg))
        loss = nn.functional.softplus(log_positive + log_negative)
        return loss
class NTXent(torch.nn.modules.module.Module):
class NTXent(nn.Module):
    """NT-Xent loss for the contrastive SSL module.

    All index pairs ``(i, j)`` with ``i <= j`` of the batch are considered. The
    diagonal cross-view pairs ``(features1[i], features2[i])`` are the positives;
    the denominators sum over all cross-view pairs plus the off-diagonal
    within-view pairs of one of the two views. Everything is evaluated in log
    space so that small temperatures do not overflow ``exp``.
    """

    def __init__(self, tau: float):
        """Initialize the loss.

        Parameters
        ----------
        tau : float
            the temperature parameter; cosine similarities are divided by it

        """
        super().__init__()
        self.tau = tau

    def _exp_similarity(self, tensor1, tensor2):
        """Compute ``exp(cosine_similarity / tau)`` between matching rows.

        Retained helper; ``forward`` works with the logarithm of this value instead.
        """
        s = torch.cosine_similarity(tensor1, tensor2, dim=1) / self.tau
        s = torch.exp(s)
        return s

    def _loss(self, similarity, true_indices, denom):
        """Compute one of the symmetric components of the loss in linear space.

        Retained helper; ``forward`` computes the same quantity in log space.
        """
        log_ratio = torch.log(similarity[true_indices] / denom)
        return -torch.mean(log_ratio)

    def forward(self, features1, features2):
        """Compute the loss.

        Parameters
        ----------
        features1, features2 : torch.Tensor
            tensors of shape `(#batch, #features)`; row ``i`` of both tensors
            forms a positive pair

        Returns
        -------
        loss : torch.Tensor
            the scalar loss value; zero when the batch has fewer than two
            elements (there is nothing to contrast against)

        """
        n_samples = len(features1)
        if n_samples < 2:
            # same cut-off as the original ``len(indices) < 3`` check, but return
            # a tensor so that ``.item()`` and tensor arithmetic keep working
            return features1.new_zeros(())
        # all index pairs (i, j) with i <= j -- the same set of pairs as
        # ``combinations_with_replacement(range(n_samples), 2)``
        rows, cols = torch.triu_indices(n_samples, n_samples, device=features1.device)
        diagonal = rows == cols

        def log_similarity(tensor1, tensor2):
            # log(exp(cos / tau)): never exponentiated, hence cannot overflow
            return torch.cosine_similarity(tensor1, tensor2, dim=1) / self.tau

        log_sim1 = log_similarity(features1[rows], features1[cols])
        log_sim2 = log_similarity(features2[rows], features2[cols])
        log_sim12 = log_similarity(features1[rows], features2[cols])
        # log(sum12 + sum1) and log(sum12 + sum2), where the within-view sums
        # skip the trivial self-similarities on the diagonal
        log_denom1 = torch.logsumexp(torch.cat([log_sim12, log_sim1[~diagonal]]), dim=0)
        log_denom2 = torch.logsumexp(torch.cat([log_sim12, log_sim2[~diagonal]]), dim=0)
        positives = log_sim12[diagonal]
        loss = -torch.mean(positives - log_denom1) - torch.mean(positives - log_denom2)
        return loss

NT-Xent loss for the contrastive SSL module.

NTXent(tau: float)
20    def __init__(self, tau: float):
21        """Initialize the loss.
22
23        Parameters
24        ----------
25        tau : float
26            the tau parameter
27
28        """
29        super().__init__()
30        self.tau = tau

Initialize the loss.

Parameters

tau : float — the temperature parameter (cosine similarities are divided by it)

tau
def forward(self, features1, features2):
44    def forward(self, features1, features2):
45        """Compute the loss.
46
47        Parameters
48        ----------
49        features1, features2 : torch.Tensor
50            tensor of shape `(#batch, #features)`
51        Returns
52        -------
53        loss : float
54            the loss value
55
56        """
57        indices = list(combinations_with_replacement(list(range(len(features1))), 2))
58        if len(indices) < 3:
59            return 0
60        indices1, indices2 = map(list, zip(*indices))
61        true = np.unique(indices1, return_index=True)[1]
62        similarity1 = self._exp_similarity(features1[indices1], features1[indices2])
63        similarity2 = self._exp_similarity(features2[indices1], features2[indices2])
64        similarity12 = self._exp_similarity(features1[indices1], features2[indices2])
65        sum1 = similarity1.sum() - similarity1[true].sum()
66        sum2 = similarity2.sum() - similarity2[true].sum()
67        sum12 = similarity12.sum()
68        loss = self._loss(similarity12, true, sum12 + sum1) + self._loss(
69            similarity12, true, sum12 + sum2
70        )
71        return loss

Compute the loss.

Parameters

features1, features2 : torch.Tensor — tensors of shape (#batch, #features)

Returns

loss : float — the loss value

class TripletLoss(torch.nn.modules.module.Module):
 74class TripletLoss(nn.Module):
 75    """Triplet loss for the pairwise SSL module.
 76
 77    A slightly modified version: a Softplus function is applied at the last step instead of ReLU to keep
 78    the result differentiable
 79    """
 80
 81    def __init__(self, margin: float = 0, distance: str = "cosine"):
 82        """Initialize the loss.
 83
 84        Parameters
 85        ----------
 86        margin : float, default 0
 87            the margin parameter
 88        distance : {'cosine', 'euclidean'}
 89            the distance metric (cosine similarity or euclidean distance)
 90
 91        """
 92        super().__init__()
 93        self.margin = margin
 94        self.nl = nn.Softplus()
 95        if distance == "euclidean":
 96            self.distance = self._euclidean_distance
 97        elif distance == "cosine":
 98            self.distance = self._cosine_similarity
 99        else:
100            raise ValueError(
101                f'The {distance} is not available, please choose from "euclidean" and "cosine"'
102            )
103
104    def _euclidean_distance(self, tensor1, tensor2):
105        """Compute euclidean distance."""
106        return torch.sum((tensor1 - tensor2) ** 2, dim=1)
107
108    def _cosine_similarity(self, tensor1, tensor2):
109        """Compute cosine similarity."""
110        return torch.cosine_similarity(tensor1, tensor2, dim=1)
111
112    def forward(self, features1, features2):
113        """Compute the loss.
114
115        Parameters
116        ----------
117        features1, features2 : torch.Tensor
118            tensor of shape `(#batch, #features)`
119        Returns
120        -------
121        loss : float
122            the loss value
123
124        """
125        negative = torch.cat([features2[1:], features2[:1]])
126        positive_distance = self.distance(features1, features2)
127        negative_distance = self.distance(features1, negative)
128        loss = torch.mean(self.nl(positive_distance - negative_distance + self.margin))
129        return loss

Triplet loss for the pairwise SSL module.

A slightly modified version: a Softplus function is applied at the last step instead of a ReLU to keep the result differentiable.

TripletLoss(margin: float = 0, distance: str = 'cosine')
 81    def __init__(self, margin: float = 0, distance: str = "cosine"):
 82        """Initialize the loss.
 83
 84        Parameters
 85        ----------
 86        margin : float, default 0
 87            the margin parameter
 88        distance : {'cosine', 'euclidean'}
 89            the distance metric (cosine similarity or euclidean distance)
 90
 91        """
 92        super().__init__()
 93        self.margin = margin
 94        self.nl = nn.Softplus()
 95        if distance == "euclidean":
 96            self.distance = self._euclidean_distance
 97        elif distance == "cosine":
 98            self.distance = self._cosine_similarity
 99        else:
100            raise ValueError(
101                f'The {distance} is not available, please choose from "euclidean" and "cosine"'
102            )

Initialize the loss.

Parameters

margin : float, default 0 — the margin parameter
distance : {'cosine', 'euclidean'} — the distance metric (cosine similarity or euclidean distance)

margin
nl
def forward(self, features1, features2):
112    def forward(self, features1, features2):
113        """Compute the loss.
114
115        Parameters
116        ----------
117        features1, features2 : torch.Tensor
118            tensor of shape `(#batch, #features)`
119        Returns
120        -------
121        loss : float
122            the loss value
123
124        """
125        negative = torch.cat([features2[1:], features2[:1]])
126        positive_distance = self.distance(features1, features2)
127        negative_distance = self.distance(features1, negative)
128        loss = torch.mean(self.nl(positive_distance - negative_distance + self.margin))
129        return loss

Compute the loss.

Parameters

features1, features2 : torch.Tensor — tensors of shape (#batch, #features)

Returns

loss : float — the loss value

class CircleLoss(torch.nn.modules.module.Module):
class CircleLoss(nn.Module):
    """Circle loss for the pairwise SSL module.

    Every index pair ``(i, j)`` with ``i <= j`` is scored with ``d(features1[i], features2[j])``;
    the diagonal pairs are positives and the remaining pairs are negatives. With
    ``distance="cosine"`` the similarity is turned into the distance ``1 - cosine_similarity``
    so that, as with the euclidean option, minimizing the loss pulls positives together and
    pushes negatives apart. The loss is evaluated in log space to avoid overflowing ``exp``.
    """

    def __init__(self, gamma: float = 1, margin: float = 0, distance: str = "cosine"):
        """Initialize the loss.

        Parameters
        ----------
        gamma : float, default 1
            the gamma (scale) parameter multiplying the exponents
        margin : float, default 0
            the margin added to the positive-pair distances
        distance : {'cosine', 'euclidean'}
            the distance metric (cosine similarity or squared euclidean distance)

        Raises
        ------
        ValueError
            if ``distance`` is not one of the supported options

        """
        super().__init__()
        self.gamma = gamma
        self.margin = margin
        if distance == "euclidean":
            self.distance = self._euclidean_distance
        elif distance == "cosine":
            self.distance = self._cosine_similarity
        else:
            raise ValueError(
                f'The {distance} is not available, please choose from "euclidean" and "cosine"'
            )

    def _euclidean_distance(self, tensor1, tensor2):
        """Compute the squared euclidean distance between matching rows."""
        return torch.sum((tensor1 - tensor2) ** 2, dim=1)

    def _cosine_similarity(self, tensor1, tensor2):
        """Compute the cosine similarity between matching rows."""
        return torch.cosine_similarity(tensor1, tensor2, dim=1)

    def forward(self, features1, features2):
        """Compute the loss.

        Parameters
        ----------
        features1, features2 : torch.Tensor
            tensors of shape `(#batch, #features)`; row ``i`` of both tensors forms a positive pair

        Returns
        -------
        loss : torch.Tensor
            the scalar value
            ``log(1 + sum_pos exp(gamma * (d + margin)) * sum_neg exp(-gamma * d))``;
            zero when the batch has fewer than two elements (no negative pairs)

        """
        n_samples = len(features1)
        # all index pairs (i, j) with i <= j -- the same set of pairs as
        # ``combinations_with_replacement(range(n_samples), 2)``
        rows, cols = torch.triu_indices(n_samples, n_samples, device=features1.device)
        positive = rows == cols
        distances = self.distance(features1[rows], features2[cols])
        if self.distance == self._cosine_similarity:
            # a similarity grows as vectors get closer; convert it to a distance
            distances = 1 - distances
        if n_samples < 2:
            # without negatives their sum is 0 and the loss is log(1) = 0; multiply
            # instead of returning a constant so the result stays in the graph
            return distances.sum() * 0.0
        log_positive = torch.logsumexp(self.gamma * (distances[positive] + self.margin), dim=0)
        log_negative = torch.logsumexp(-self.gamma * distances[~positive], dim=0)
        # log(1 + sum_pos * sum_neg) == softplus(log(sum_pos) + log(sum_neg))
        loss = nn.functional.softplus(log_positive + log_negative)
        return loss

Circle loss for the pairwise SSL module.

CircleLoss(gamma: float = 1, margin: float = 0, distance: str = 'cosine')
135    def __init__(self, gamma: float = 1, margin: float = 0, distance: str = "cosine"):
136        """Initialize the loss.
137
138        Parameters
139        ----------
140        gamma : float, default 1
141            the gamma parameter
142        margin : float, default 0
143            the margin parameter
144        distance : {'cosine', 'euclidean'}
145            the distance metric (cosine similarity or euclidean distance)
146
147        """
148        super().__init__()
149        self.gamma = gamma
150        self.margin = margin
151        if distance == "euclidean":
152            self.distance = self._euclidean_distance
153        elif distance == "cosine":
154            self.distance = self._cosine_similarity
155        else:
156            raise ValueError(
157                f'The {distance} is not available, please choose from "euclidean" and "cosine"'
158            )

Initialize the loss.

Parameters

gamma : float, default 1 — the gamma parameter
margin : float, default 0 — the margin parameter
distance : {'cosine', 'euclidean'} — the distance metric (cosine similarity or euclidean distance)

gamma
margin
def forward(self, features1, features2):
168    def forward(self, features1, features2):
169        """Compute the loss.
170
171        Parameters
172        ----------
173        features1, features2 : torch.Tensor
174            tensor of shape `(#batch, #features)`
175
176        Returns
177        -------
178        loss : float
179            the loss value
180
181        """
182        indices = list(combinations_with_replacement(list(range(len(features1))), 2))
183        indices1, indices2 = map(list, zip(*indices))
184        true = np.unique(indices1, return_index=True)[1]
185        mask = torch.zeros(len(indices1)).bool()
186        mask[true] = True
187        distances = self.distance(features1[indices1], features2[indices2])
188        distances[mask] = distances[mask] + self.margin
189        distances[~mask] = distances[~mask] * (-1)
190        distances = torch.exp(self.gamma * distances)
191        loss = torch.log(1 + torch.sum(distances[mask]) * torch.sum(distances[~mask]))
192        return loss

Compute the loss.

Parameters

features1, features2 : torch.Tensor — tensors of shape (#batch, #features)

Returns

loss : float — the loss value