dlc2action.loss.contrastive_frame

 1#
 2# Copyright 2020-2022 by A. Mathis Group and contributors. All rights reserved.
 3#
 4# This project and all its files are licensed under GNU AGPLv3 or later version. A copy is included in dlc2action/LICENSE.AGPL.
 5#
 6from torch import nn
 7import torch
 8
 9
10class _ContrastiveRegressionLoss(nn.Module):
11    def __init__(self, temperature: float, distance: str, break_factor: int):
12        self.temp = temperature
13        self.distance = distance
14        self.break_factor = break_factor
15        assert distance in ["l1", "l2", "cosine"]
16        super(_ContrastiveRegressionLoss, self).__init__()
17
18    def _get_distance_matrix(self, tensor1, tensor2) -> torch.Tensor:
19        """
20        Shape (#features, #frames)
21        """
22
23        if self.distance == "l1":
24            dist = torch.cdist(tensor1.T, tensor2.T, p=1)
25        elif self.distance == "l2":
26            dist = torch.cdist(tensor1.T, tensor2.T, p=2)
27        else:
28            dist = torch.nn.functional.cosine_similarity(
29                tensor1.t()[:, :, None], tensor2[None, :, :]
30            )
31        return dist
32
33    def forward(self, tensor1, tensor2):
34        """
35        Compute the loss
36
37        Parameters
38        ----------
39        tensor1, tensor2 : torch.Tensor
40            tensor of shape `(#batch, #features, #frames)`
41        Returns
42        -------
43        loss : float
44            the loss value
45        """
46
47        loss = 0
48        if self.break_factor is not None:
49            B, C, T = tensor1.shape
50            if T % self.break_factor != 0:
51                tensor1 = tensor1[:, :, : -(T % self.break_factor)]
52                tensor2 = tensor2[:, :, : -(T % self.break_factor)]
53                T -= T % self.break_factor
54            tensor1 = tensor1.reshape((B, C, self.break_factor, -1))
55            tensor1 = torch.transpose(tensor1, 1, 2).reshape(
56                (B * self.break_factor, C, -1)
57            )
58            tensor2 = tensor2.reshape((B, C, self.break_factor, -1))
59            tensor2 = torch.transpose(tensor2, 1, 2).reshape(
60                (B * self.break_factor, C, -1)
61            )
62        indices = torch.tensor(range(tensor1.shape[-1])).to(tensor1.device)
63        for i in range(tensor1.shape[0]):
64            out = torch.exp(
65                self._get_distance_matrix(tensor1[i], tensor2[i]) / self.temp
66            )
67            out = out / (torch.sum(out, 1).unsqueeze(1) + 1e-7)
68            out = torch.sum(out * indices.unsqueeze(0), 1)
69            loss += torch.sum((out - indices) ** 2)
70        return loss / (tensor1.shape[-1] * tensor1.shape[0])