dlc2action.loss.contrastive_frame

 1#
 2# Copyright 2020-present by A. Mathis Group and contributors. All rights reserved.
 3#
 4# This project and all its files are licensed under GNU AGPLv3 or later version. 
 5# A copy is included in dlc2action/LICENSE.AGPL.
 6#
 7import torch
 8from torch import nn
 9
10
class ContrastiveRegressionLoss(nn.Module):
    """Contrastive regression loss for the pairwise SSL module.

    For every frame ``i`` of a sample in ``tensor1``, a softmax over its
    similarity to every frame of the matching sample in ``tensor2`` is
    computed. The softmax-weighted (expected) frame index is then regressed
    onto ``i`` with a squared error, averaged over frames and samples.
    """

    def __init__(self, temperature: float, distance: str, break_factor: int):
        """Initialize the loss.

        Parameters
        ----------
        temperature : float
            the temperature parameter (similarities are divided by it before
            the softmax)
        distance : {'cosine', 'l1', 'l2'}
            the frame comparison metric: cosine similarity, L1 (Manhattan)
            distance or L2 (Euclidean) distance
        break_factor : int or None
            if not ``None``, every sequence is cut into ``break_factor``
            equal-length chunks along the frame dimension and each chunk is
            treated as a separate sample

        Raises
        ------
        ValueError
            if ``distance`` is not supported or ``break_factor`` is smaller
            than 1

        """
        # nn.Module has to be initialized before any attribute is assigned.
        super().__init__()
        # Explicit check instead of ``assert``, which is stripped under -O.
        if distance not in ("l1", "l2", "cosine"):
            raise ValueError(
                f"distance must be one of 'l1', 'l2', 'cosine', got {distance!r}"
            )
        if break_factor is not None and break_factor < 1:
            raise ValueError(
                f"break_factor must be a positive integer or None, got {break_factor!r}"
            )
        self.temp = temperature
        self.distance = distance
        self.break_factor = break_factor

    def _get_distance_matrix(self, tensor1, tensor2) -> torch.Tensor:
        """Compare every frame of ``tensor1`` with every frame of ``tensor2``.

        Parameters
        ----------
        tensor1, tensor2 : torch.Tensor
            tensors of shape ``(#features, #frames)``

        Returns
        -------
        dist : torch.Tensor
            matrix of shape ``(#frames1, #frames2)`` holding pairwise
            distances for ``'l1'`` / ``'l2'`` and cosine similarities for
            ``'cosine'``

        """
        if self.distance == "l1":
            return torch.cdist(tensor1.T, tensor2.T, p=1)
        if self.distance == "l2":
            return torch.cdist(tensor1.T, tensor2.T, p=2)
        # (frames1, features, 1) vs (1, features, frames2): the similarity is
        # taken along the feature dimension (dim=1) and broadcast over frames.
        return torch.nn.functional.cosine_similarity(
            tensor1.T[:, :, None], tensor2[None, :, :]
        )

    def forward(self, tensor1, tensor2):
        """Compute the loss.

        Parameters
        ----------
        tensor1, tensor2 : torch.Tensor
            tensors of shape `(#batch, #features, #frames)`

        Returns
        -------
        loss : torch.Tensor
            a scalar tensor with the loss value

        Raises
        ------
        ValueError
            if ``break_factor`` is larger than the number of frames

        """
        if self.break_factor is not None:
            n_frames = tensor1.shape[-1]
            if n_frames < self.break_factor:
                raise ValueError(
                    f"break_factor ({self.break_factor}) is larger than the "
                    f"number of frames ({n_frames})"
                )
            # Drop trailing frames so the length is divisible by break_factor.
            usable = n_frames - n_frames % self.break_factor
            batch, channels = tensor1.shape[:2]
            # (B, C, bf, T/bf) -> (B, bf, C, T/bf) -> (B * bf, C, T/bf):
            # each contiguous chunk of frames becomes its own sample.
            tensor1 = (
                tensor1[:, :, :usable]
                .reshape(batch, channels, self.break_factor, -1)
                .transpose(1, 2)
                .reshape(batch * self.break_factor, channels, -1)
            )
            tensor2 = (
                tensor2[:, :, :usable]
                .reshape(batch, channels, self.break_factor, -1)
                .transpose(1, 2)
                .reshape(batch * self.break_factor, channels, -1)
            )
        n_samples, _, n_frames = tensor1.shape
        indices = torch.arange(n_frames, device=tensor1.device)
        loss = 0
        for sample1, sample2 in zip(tensor1, tensor2):
            sim = self._get_distance_matrix(sample1, sample2)
            if self.distance != "cosine":
                # l1 / l2 are distances: negate them so that the closest
                # frame (not the farthest) receives the highest weight.
                sim = -sim
            # Numerically stable softmax (no exp overflow for large values).
            weights = torch.softmax(sim / self.temp, dim=1)
            expected = (weights * indices).sum(dim=1)
            loss += ((expected - indices) ** 2).sum()
        return loss / (n_frames * n_samples)
class ContrastiveRegressionLoss(torch.nn.modules.module.Module):
class ContrastiveRegressionLoss(nn.Module):
    """Contrastive regression loss for the pairwise SSL module.

    For every frame ``i`` of a sample in ``tensor1``, a softmax over its
    similarity to every frame of the matching sample in ``tensor2`` is
    computed. The softmax-weighted (expected) frame index is then regressed
    onto ``i`` with a squared error, averaged over frames and samples.
    """

    def __init__(self, temperature: float, distance: str, break_factor: int):
        """Initialize the loss.

        Parameters
        ----------
        temperature : float
            the temperature parameter (similarities are divided by it before
            the softmax)
        distance : {'cosine', 'l1', 'l2'}
            the frame comparison metric: cosine similarity, L1 (Manhattan)
            distance or L2 (Euclidean) distance
        break_factor : int or None
            if not ``None``, every sequence is cut into ``break_factor``
            equal-length chunks along the frame dimension and each chunk is
            treated as a separate sample

        Raises
        ------
        ValueError
            if ``distance`` is not supported or ``break_factor`` is smaller
            than 1

        """
        # nn.Module has to be initialized before any attribute is assigned.
        super().__init__()
        # Explicit check instead of ``assert``, which is stripped under -O.
        if distance not in ("l1", "l2", "cosine"):
            raise ValueError(
                f"distance must be one of 'l1', 'l2', 'cosine', got {distance!r}"
            )
        if break_factor is not None and break_factor < 1:
            raise ValueError(
                f"break_factor must be a positive integer or None, got {break_factor!r}"
            )
        self.temp = temperature
        self.distance = distance
        self.break_factor = break_factor

    def _get_distance_matrix(self, tensor1, tensor2) -> torch.Tensor:
        """Compare every frame of ``tensor1`` with every frame of ``tensor2``.

        Parameters
        ----------
        tensor1, tensor2 : torch.Tensor
            tensors of shape ``(#features, #frames)``

        Returns
        -------
        dist : torch.Tensor
            matrix of shape ``(#frames1, #frames2)`` holding pairwise
            distances for ``'l1'`` / ``'l2'`` and cosine similarities for
            ``'cosine'``

        """
        if self.distance == "l1":
            return torch.cdist(tensor1.T, tensor2.T, p=1)
        if self.distance == "l2":
            return torch.cdist(tensor1.T, tensor2.T, p=2)
        # (frames1, features, 1) vs (1, features, frames2): the similarity is
        # taken along the feature dimension (dim=1) and broadcast over frames.
        return torch.nn.functional.cosine_similarity(
            tensor1.T[:, :, None], tensor2[None, :, :]
        )

    def forward(self, tensor1, tensor2):
        """Compute the loss.

        Parameters
        ----------
        tensor1, tensor2 : torch.Tensor
            tensors of shape `(#batch, #features, #frames)`

        Returns
        -------
        loss : torch.Tensor
            a scalar tensor with the loss value

        Raises
        ------
        ValueError
            if ``break_factor`` is larger than the number of frames

        """
        if self.break_factor is not None:
            n_frames = tensor1.shape[-1]
            if n_frames < self.break_factor:
                raise ValueError(
                    f"break_factor ({self.break_factor}) is larger than the "
                    f"number of frames ({n_frames})"
                )
            # Drop trailing frames so the length is divisible by break_factor.
            usable = n_frames - n_frames % self.break_factor
            batch, channels = tensor1.shape[:2]
            # (B, C, bf, T/bf) -> (B, bf, C, T/bf) -> (B * bf, C, T/bf):
            # each contiguous chunk of frames becomes its own sample.
            tensor1 = (
                tensor1[:, :, :usable]
                .reshape(batch, channels, self.break_factor, -1)
                .transpose(1, 2)
                .reshape(batch * self.break_factor, channels, -1)
            )
            tensor2 = (
                tensor2[:, :, :usable]
                .reshape(batch, channels, self.break_factor, -1)
                .transpose(1, 2)
                .reshape(batch * self.break_factor, channels, -1)
            )
        n_samples, _, n_frames = tensor1.shape
        indices = torch.arange(n_frames, device=tensor1.device)
        loss = 0
        for sample1, sample2 in zip(tensor1, tensor2):
            sim = self._get_distance_matrix(sample1, sample2)
            if self.distance != "cosine":
                # l1 / l2 are distances: negate them so that the closest
                # frame (not the farthest) receives the highest weight.
                sim = -sim
            # Numerically stable softmax (no exp overflow for large values).
            weights = torch.softmax(sim / self.temp, dim=1)
            expected = (weights * indices).sum(dim=1)
            loss += ((expected - indices) ** 2).sum()
        return loss / (n_frames * n_samples)

Contrastive regression loss for the pairwise SSL module.

ContrastiveRegressionLoss(temperature: float, distance: str, break_factor: int)
15    def __init__(self, temperature: float, distance: str, break_factor: int):
16        """Initialize the loss.
17
18        Parameters
19        ----------
20        temperature : float
21            the temperature parameter
22        distance : {'cosine', 'l1', 'l2'}
23            the distance metric (cosine similarity or euclidean distance)
24        break_factor : int
25            the break factor for the length dimension
26
27        """
28        self.temp = temperature
29        self.distance = distance
30        self.break_factor = break_factor
31        assert distance in ["l1", "l2", "cosine"]
32        super(ContrastiveRegressionLoss, self).__init__()

Initialize the loss.

Parameters

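temperature : float — the temperature parameter. distance : {'cosine', 'l1', 'l2'} — the frame comparison metric (cosine similarity, L1/Manhattan distance, or L2/Euclidean distance). break_factor : int — the number of equal-length chunks each sequence is split into along the frame dimension (no splitting when None).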

temp
distance
break_factor
def forward(self, tensor1, tensor2):
46    def forward(self, tensor1, tensor2):
47        """Compute the loss.
48
49        Parameters
50        ----------
51        tensor1, tensor2 : torch.Tensor
52            tensor of shape `(#batch, #features, #frames)`
53
54        Returns
55        -------
56        loss : float
57            the loss value
58
59        """
60        loss = 0
61        if self.break_factor is not None:
62            B, C, T = tensor1.shape
63            if T % self.break_factor != 0:
64                tensor1 = tensor1[:, :, : -(T % self.break_factor)]
65                tensor2 = tensor2[:, :, : -(T % self.break_factor)]
66                T -= T % self.break_factor
67            tensor1 = tensor1.reshape((B, C, self.break_factor, -1))
68            tensor1 = torch.transpose(tensor1, 1, 2).reshape(
69                (B * self.break_factor, C, -1)
70            )
71            tensor2 = tensor2.reshape((B, C, self.break_factor, -1))
72            tensor2 = torch.transpose(tensor2, 1, 2).reshape(
73                (B * self.break_factor, C, -1)
74            )
75        indices = torch.tensor(range(tensor1.shape[-1])).to(tensor1.device)
76        for i in range(tensor1.shape[0]):
77            out = torch.exp(
78                self._get_distance_matrix(tensor1[i], tensor2[i]) / self.temp
79            )
80            out = out / (torch.sum(out, 1).unsqueeze(1) + 1e-7)
81            out = torch.sum(out * indices.unsqueeze(0), 1)
82            loss += torch.sum((out - indices) ** 2)
83        return loss / (tensor1.shape[-1] * tensor1.shape[0])

Compute the loss.

Parameters

tensor1, tensor2 : torch.Tensor — tensors of shape (#batch, #features, #frames)

Returns

loss : torch.Tensor — a scalar tensor holding the loss value