1#
2# Copyright 2020-2022 by A. Mathis Group and contributors. All rights reserved.
3#
4# This project and all its files are licensed under GNU AGPLv3 or later version. A copy is included in dlc2action/LICENSE.AGPL.
5#
from typing import Optional

import torch
from torch import nn
8
9
class _ContrastiveRegressionLoss(nn.Module):
    """
    Contrastive regression loss over frame indices.

    For every frame ``i`` of the first tensor, a softmax over its similarity to
    all frames of the second tensor yields a soft "matched frame index"
    (the softmax-weighted mean of frame indices). The loss is the squared
    difference between that soft index and ``i``, averaged over frames and
    sequences. It is minimal when every frame is most similar to the frame
    with the same index in the other tensor.
    """

    def __init__(
        self, temperature: float, distance: str, break_factor: Optional[int]
    ):
        """
        Parameters
        ----------
        temperature : float
            softmax temperature; similarities are divided by it
        distance : {"l1", "l2", "cosine"}
            the frame (dis)similarity measure
        break_factor : int, optional
            if not ``None``, every sequence is cut into this many equal
            contiguous segments (trailing frames that do not fit are dropped)
            and frames are only compared within their own segment

        Raises
        ------
        ValueError
            if ``distance`` is not supported or ``break_factor`` is not a
            positive integer
        """
        # initialise the Module machinery before assigning any attributes
        super().__init__()
        if distance not in ("l1", "l2", "cosine"):
            raise ValueError(
                f"distance must be one of 'l1', 'l2', 'cosine', got {distance!r}"
            )
        if break_factor is not None and break_factor < 1:
            raise ValueError(
                f"break_factor must be a positive integer or None, got {break_factor!r}"
            )
        self.temp = temperature
        self.distance = distance
        self.break_factor = break_factor

    def _get_distance_matrix(self, tensor1, tensor2) -> torch.Tensor:
        """
        Compute the pairwise frame similarity matrix.

        Higher values always mean more similar frames: ``l1``/``l2`` distances
        are negated so that every measure can be fed to a softmax directly.

        Parameters
        ----------
        tensor1, tensor2 : torch.Tensor
            tensors of shape ``(..., #features, #frames)``; any leading batch
            dimensions must be broadcastable

        Returns
        -------
        similarity : torch.Tensor
            tensor of shape ``(..., #frames1, #frames2)``
        """
        if self.distance in ("l1", "l2"):
            p = 1 if self.distance == "l1" else 2
            # cdist works on (..., points, features), hence the transposes;
            # the minus sign turns a distance into a similarity
            return -torch.cdist(
                tensor1.transpose(-1, -2), tensor2.transpose(-1, -2), p=p
            )
        # (..., T1, C, 1) vs (..., 1, C, T2) broadcast to (..., T1, C, T2);
        # reducing over the feature axis gives (..., T1, T2)
        return torch.nn.functional.cosine_similarity(
            tensor1.transpose(-1, -2).unsqueeze(-1),
            tensor2.unsqueeze(-3),
            dim=-2,
        )

    def _split_segments(self, tensor: torch.Tensor) -> torch.Tensor:
        """
        Cut ``(B, C, T)`` into ``(B * break_factor, C, T // break_factor)``.

        The last ``T % break_factor`` frames are dropped. Segments are
        contiguous chunks of frames.

        Raises
        ------
        ValueError
            if the sequences have fewer frames than ``break_factor``
        """
        B, C, T = tensor.shape
        segment = T // self.break_factor
        if segment == 0:
            raise ValueError(
                f"sequences have {T} frames, fewer than break_factor={self.break_factor}"
            )
        tensor = tensor[:, :, : segment * self.break_factor]
        tensor = tensor.reshape(B, C, self.break_factor, segment)
        return tensor.transpose(1, 2).reshape(B * self.break_factor, C, segment)

    def forward(self, tensor1: torch.Tensor, tensor2: torch.Tensor) -> torch.Tensor:
        """
        Compute the loss.

        Parameters
        ----------
        tensor1, tensor2 : torch.Tensor
            tensors of shape ``(#batch, #features, #frames)``; both must have
            the same number of frames

        Returns
        -------
        loss : torch.Tensor
            scalar tensor: the squared error between each frame's soft-matched
            index and its true index, averaged over frames and sequences
            (zero if there is nothing to compare)
        """
        if self.break_factor is not None:
            tensor1 = self._split_segments(tensor1)
            tensor2 = self._split_segments(tensor2)
        n_sequences, _, n_frames = tensor1.shape
        if n_sequences == 0 or n_frames == 0:
            # nothing to compare; keep the result connected to the graph
            return tensor1.sum() * 0.0
        similarity = self._get_distance_matrix(tensor1, tensor2) / self.temp
        # softmax subtracts the row maximum internally, so large similarities
        # cannot overflow exp() the way a manual exp / sum normalisation can
        weights = torch.softmax(similarity, dim=-1)
        indices = torch.arange(n_frames, device=weights.device, dtype=weights.dtype)
        # soft matched index for every frame: (N, T, T) @ (T,) -> (N, T)
        predicted = weights @ indices
        return ((predicted - indices) ** 2).sum() / (n_frames * n_sequences)