dlc2action.loss.contrastive_frame
#
# Copyright 2020-present by A. Mathis Group and contributors. All rights reserved.
#
# This project and all its files are licensed under GNU AGPLv3 or later version.
# A copy is included in dlc2action/LICENSE.AGPL.
#
import torch
from torch import nn


class ContrastiveRegressionLoss(nn.Module):
    """Contrastive regression loss for the pairwise SSL module.

    Every frame of the first sequence is softly matched against all frames of
    the second sequence; the expected index of the match is then regressed
    towards the frame's own index.
    """

    def __init__(self, temperature: float, distance: str, break_factor: int):
        """Initialize the loss.

        Parameters
        ----------
        temperature : float
            the temperature the pairwise scores are divided by before the
            softmax
        distance : {'cosine', 'l1', 'l2'}
            the metric used to compare frames (cosine similarity, L1 distance
            or L2 distance)
        break_factor : int or None
            the number of equal chunks the frame dimension is split into (each
            chunk is handled as a separate sequence); `None` disables the
            splitting

        Raises
        ------
        ValueError
            if `distance` is not one of the supported metrics or if
            `break_factor` is smaller than 1

        """
        super().__init__()
        # Explicit checks instead of `assert`, which is stripped under `-O`.
        if distance not in ("l1", "l2", "cosine"):
            raise ValueError(
                f"distance must be one of 'l1', 'l2' or 'cosine', got {distance!r}"
            )
        if break_factor is not None and break_factor < 1:
            raise ValueError(
                f"break_factor must be a positive integer or None, got {break_factor!r}"
            )
        self.temp = temperature
        self.distance = distance
        self.break_factor = break_factor

    def _get_distance_matrix(self, tensor1, tensor2) -> torch.Tensor:
        """Score every frame of `tensor1` against every frame of `tensor2`.

        Both inputs have shape `(#features, #frames)`; the result has shape
        `(#frames1, #frames2)`. Larger values always mean "more similar":
        cosine similarity is returned as is, L1 / L2 distances are negated so
        that the softmax in `forward` favours the closest frame for every
        metric.
        """
        if self.distance == "cosine":
            return torch.nn.functional.cosine_similarity(
                tensor1.t()[:, :, None], tensor2[None, :, :]
            )
        norm = 1 if self.distance == "l1" else 2
        return -torch.cdist(tensor1.T, tensor2.T, p=norm)

    def _split_frames(self, tensor, usable):
        """Cut the first `usable` frames into `break_factor` sequences."""
        batch, channels = tensor.shape[:2]
        chunks = tensor[..., :usable].reshape(
            (batch, channels, self.break_factor, -1)
        )
        # Move the chunk axis next to the batch axis and merge the two.
        return torch.transpose(chunks, 1, 2).reshape(
            (batch * self.break_factor, channels, -1)
        )

    def forward(self, tensor1, tensor2):
        """Compute the loss.

        Parameters
        ----------
        tensor1, tensor2 : torch.Tensor
            tensor of shape `(#batch, #features, #frames)`

        Returns
        -------
        loss : torch.Tensor
            the loss value (a scalar tensor), averaged over frames and
            sequences

        Raises
        ------
        ValueError
            if the sequences are shorter than `break_factor`

        """
        if self.break_factor is not None:
            n_frames = tensor1.shape[-1]
            # Trailing frames that do not fill a whole chunk are dropped.
            usable = n_frames - n_frames % self.break_factor
            if usable == 0:
                raise ValueError(
                    f"cannot split {n_frames} frames into "
                    f"{self.break_factor} chunks"
                )
            tensor1 = self._split_frames(tensor1, usable)
            tensor2 = self._split_frames(tensor2, usable)
        positions = torch.arange(
            tensor1.shape[-1], device=tensor1.device, dtype=tensor1.dtype
        )
        loss = tensor1.new_zeros(())
        for seq1, seq2 in zip(tensor1, tensor2):
            # softmax is the overflow-safe form of exp(x) / sum(exp(x)).
            weights = torch.softmax(
                self._get_distance_matrix(seq1, seq2) / self.temp, dim=1
            )
            # Expected index of the matching frame in the second sequence.
            expected = weights @ positions
            loss = loss + torch.sum((expected - positions) ** 2)
        return loss / (tensor1.shape[-1] * tensor1.shape[0])
class
ContrastiveRegressionLoss(torch.nn.modules.module.Module):
class ContrastiveRegressionLoss(nn.Module):
    """Contrastive regression loss for the pairwise SSL module.

    Every frame of the first sequence is softly matched against all frames of
    the second sequence; the expected index of the match is then regressed
    towards the frame's own index.
    """

    def __init__(self, temperature: float, distance: str, break_factor: int):
        """Initialize the loss.

        Parameters
        ----------
        temperature : float
            the temperature the pairwise scores are divided by before the
            softmax
        distance : {'cosine', 'l1', 'l2'}
            the metric used to compare frames (cosine similarity, L1 distance
            or L2 distance)
        break_factor : int or None
            the number of equal chunks the frame dimension is split into (each
            chunk is handled as a separate sequence); `None` disables the
            splitting

        Raises
        ------
        ValueError
            if `distance` is not one of the supported metrics or if
            `break_factor` is smaller than 1

        """
        super().__init__()
        # Explicit checks instead of `assert`, which is stripped under `-O`.
        if distance not in ("l1", "l2", "cosine"):
            raise ValueError(
                f"distance must be one of 'l1', 'l2' or 'cosine', got {distance!r}"
            )
        if break_factor is not None and break_factor < 1:
            raise ValueError(
                f"break_factor must be a positive integer or None, got {break_factor!r}"
            )
        self.temp = temperature
        self.distance = distance
        self.break_factor = break_factor

    def _get_distance_matrix(self, tensor1, tensor2) -> torch.Tensor:
        """Score every frame of `tensor1` against every frame of `tensor2`.

        Both inputs have shape `(#features, #frames)`; the result has shape
        `(#frames1, #frames2)`. Larger values always mean "more similar":
        cosine similarity is returned as is, L1 / L2 distances are negated so
        that the softmax in `forward` favours the closest frame for every
        metric.
        """
        if self.distance == "cosine":
            return torch.nn.functional.cosine_similarity(
                tensor1.t()[:, :, None], tensor2[None, :, :]
            )
        norm = 1 if self.distance == "l1" else 2
        return -torch.cdist(tensor1.T, tensor2.T, p=norm)

    def _split_frames(self, tensor, usable):
        """Cut the first `usable` frames into `break_factor` sequences."""
        batch, channels = tensor.shape[:2]
        chunks = tensor[..., :usable].reshape(
            (batch, channels, self.break_factor, -1)
        )
        # Move the chunk axis next to the batch axis and merge the two.
        return torch.transpose(chunks, 1, 2).reshape(
            (batch * self.break_factor, channels, -1)
        )

    def forward(self, tensor1, tensor2):
        """Compute the loss.

        Parameters
        ----------
        tensor1, tensor2 : torch.Tensor
            tensor of shape `(#batch, #features, #frames)`

        Returns
        -------
        loss : torch.Tensor
            the loss value (a scalar tensor), averaged over frames and
            sequences

        Raises
        ------
        ValueError
            if the sequences are shorter than `break_factor`

        """
        if self.break_factor is not None:
            n_frames = tensor1.shape[-1]
            # Trailing frames that do not fill a whole chunk are dropped.
            usable = n_frames - n_frames % self.break_factor
            if usable == 0:
                raise ValueError(
                    f"cannot split {n_frames} frames into "
                    f"{self.break_factor} chunks"
                )
            tensor1 = self._split_frames(tensor1, usable)
            tensor2 = self._split_frames(tensor2, usable)
        positions = torch.arange(
            tensor1.shape[-1], device=tensor1.device, dtype=tensor1.dtype
        )
        loss = tensor1.new_zeros(())
        for seq1, seq2 in zip(tensor1, tensor2):
            # softmax is the overflow-safe form of exp(x) / sum(exp(x)).
            weights = torch.softmax(
                self._get_distance_matrix(seq1, seq2) / self.temp, dim=1
            )
            # Expected index of the matching frame in the second sequence.
            expected = weights @ positions
            loss = loss + torch.sum((expected - positions) ** 2)
        return loss / (tensor1.shape[-1] * tensor1.shape[0])
Contrastive regression loss for the pairwise SSL module.
ContrastiveRegressionLoss(temperature: float, distance: str, break_factor: int)
15 def __init__(self, temperature: float, distance: str, break_factor: int): 16 """Initialize the loss. 17 18 Parameters 19 ---------- 20 temperature : float 21 the temperature parameter 22 distance : {'cosine', 'l1', 'l2'} 23 the distance metric (cosine similarity or euclidean distance) 24 break_factor : int 25 the break factor for the length dimension 26 27 """ 28 self.temp = temperature 29 self.distance = distance 30 self.break_factor = break_factor 31 assert distance in ["l1", "l2", "cosine"] 32 super(ContrastiveRegressionLoss, self).__init__()
Initialize the loss.
Parameters
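temperature : float — the temperature that the pairwise frame scores are divided by before they are exponentiated. distance : {'cosine', 'l1', 'l2'} — the metric used to compare frames (cosine similarity, L1 distance or L2 distance). break_factor : int — the number of equal chunks the frame (length) dimension is split into; each chunk is treated as a separate sequence, and trailing frames that do not fill a whole chunk are dropped.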
def
forward(self, tensor1, tensor2):
46 def forward(self, tensor1, tensor2): 47 """Compute the loss. 48 49 Parameters 50 ---------- 51 tensor1, tensor2 : torch.Tensor 52 tensor of shape `(#batch, #features, #frames)` 53 54 Returns 55 ------- 56 loss : float 57 the loss value 58 59 """ 60 loss = 0 61 if self.break_factor is not None: 62 B, C, T = tensor1.shape 63 if T % self.break_factor != 0: 64 tensor1 = tensor1[:, :, : -(T % self.break_factor)] 65 tensor2 = tensor2[:, :, : -(T % self.break_factor)] 66 T -= T % self.break_factor 67 tensor1 = tensor1.reshape((B, C, self.break_factor, -1)) 68 tensor1 = torch.transpose(tensor1, 1, 2).reshape( 69 (B * self.break_factor, C, -1) 70 ) 71 tensor2 = tensor2.reshape((B, C, self.break_factor, -1)) 72 tensor2 = torch.transpose(tensor2, 1, 2).reshape( 73 (B * self.break_factor, C, -1) 74 ) 75 indices = torch.tensor(range(tensor1.shape[-1])).to(tensor1.device) 76 for i in range(tensor1.shape[0]): 77 out = torch.exp( 78 self._get_distance_matrix(tensor1[i], tensor2[i]) / self.temp 79 ) 80 out = out / (torch.sum(out, 1).unsqueeze(1) + 1e-7) 81 out = torch.sum(out * indices.unsqueeze(0), 1) 82 loss += torch.sum((out - indices) ** 2) 83 return loss / (tensor1.shape[-1] * tensor1.shape[0])
Compute the loss.
Parameters
tensor1, tensor2 : torch.Tensor
tensor of shape (#batch, #features, #frames)
Returns
loss : torch.Tensor — the loss value (a scalar tensor), averaged over frames and sequences