dlc2action.loss.tcc
TCC loss (from https://github.com/June01/tcc_Temporal_Cycle_Consistency_Loss.pytorch).
#
# Copyright 2020-present by A. Mathis Group and contributors. All rights reserved.
#
# This project and all its files are licensed under GNU AGPLv3 or later version. A copy is included in dlc2action/LICENSE.AGPL.
#
# Incorporates code adapted from Temporal Cycle-Consistency by June01
# Original work Copyright (c) 2004 June01
# Source: https://github.com/June01/tcc_Temporal_Cycle_Consistency_Loss.pytorch
# Originally licensed under Apache License Version 2.0, January 2004
# Combined work licensed under GNU AGPLv3
#
"""TCC loss (from https://github.com/June01/tcc_Temporal_Cycle_Consistency_Loss.pytorch)."""

import torch
from torch import nn
from torch.nn import functional as F


class TCCLoss(nn.Module):
    """Temporal Cycle Consistency Loss."""

    def __init__(
        self,
        loss_type: str = "regression_mse_var",
        variance_lambda: float = 0.001,
        normalize_indices: bool = True,
        normalize_embeddings: bool = False,
        similarity_type: str = "l2",
        num_cycles: int = 20,
        cycle_length: int = 2,
        temperature: float = 0.1,
        label_smoothing: float = 0.1,
    ):
        """Initialize the loss."""
        super().__init__()
        self.loss_type = loss_type
        self.variance_lambda = variance_lambda
        self.normalize_indices = normalize_indices
        self.normalize_embeddings = normalize_embeddings
        self.similarity_type = similarity_type
        self.num_cycles = num_cycles
        self.cycle_length = cycle_length
        self.temperature = temperature
        self.label_smoothing = label_smoothing

    def forward(self, predictions: torch.Tensor, mask: torch.Tensor) -> float:
        """Forward pass."""
        real_lens = mask.sum(-1).squeeze().cpu()
        if len(real_lens.shape) == 0:
            return torch.tensor(0)
        return compute_alignment_loss(
            predictions.transpose(-1, -2),
            normalize_embeddings=self.normalize_embeddings,
            normalize_indices=self.normalize_indices,
            loss_type=self.loss_type,
            similarity_type=self.similarity_type,
            num_cycles=self.num_cycles,
            cycle_length=self.cycle_length,
            temperature=self.temperature,
            label_smoothing=self.label_smoothing,
            variance_lambda=self.variance_lambda,
            real_lens=real_lens,
        )


def _align_single_cycle(
    cycle, embs, cycle_length, num_steps, real_len, similarity_type, temperature
):
    # Choose a random valid frame.
    n_idx = (torch.rand(1) * real_len).long()[0]

    # Create labels.
    onehot_labels = torch.eye(num_steps)[n_idx]

    # Choose query feats for the first frame.
    query_feats = embs[cycle[0], n_idx : n_idx + 1]
    num_channels = query_feats.size(-1)
    for c in range(1, cycle_length + 1):
        candidate_feats = embs[cycle[c]]
        if similarity_type == "l2":
            mean_squared_distance = torch.sum(
                (query_feats.repeat([num_steps, 1]) - candidate_feats) ** 2, dim=1
            )
            similarity = -mean_squared_distance
        elif similarity_type == "cosine":
            similarity = torch.squeeze(
                torch.matmul(candidate_feats, query_feats.transpose(0, 1))
            )
        else:
            raise ValueError("similarity_type can either be l2 or cosine.")

        similarity /= float(num_channels)
        similarity /= temperature

        beta = F.softmax(similarity, dim=0).unsqueeze(1).repeat([1, num_channels])
        query_feats = torch.sum(beta * candidate_feats, dim=0, keepdim=True)

    return similarity.unsqueeze(0), onehot_labels.unsqueeze(0)


def _align(
    cycles,
    embs,
    num_steps,
    real_lens,
    num_cycles,
    cycle_length,
    similarity_type,
    temperature,
    batch_size,
):
    """Align by finding cycles in embs."""
    logits_list = []
    labels_list = []
    for i in range(num_cycles):
        if len(real_lens) == batch_size:
            real_len = int(real_lens[cycles[i][0]])
        else:
            real_len = int(real_lens[cycles[i][0] // batch_size])
        logits, labels = _align_single_cycle(
            cycles[i],
            embs,
            cycle_length,
            num_steps,
            real_len,
            similarity_type,
            temperature,
        )
        logits_list.append(logits)
        labels_list.append(labels)

    logits = torch.cat(logits_list, dim=0)
    labels = torch.cat(labels_list, dim=0).to(embs.device)

    return logits, labels


def gen_cycles(num_cycles, batch_size, cycle_length=2):
    """Generate cycles for alignment.

    Generates a batch of indices to cycle over. For example, setting num_cycles=2,
    batch_size=5, cycle_length=3 might return something like this:
    cycles = [[0, 3, 4, 0], [1, 2, 0, 3]]. This means we have 2 cycles for which
    the loss will be calculated. The first cycle starts at sequence 0 of the
    batch, then we find a matching step in sequence 3 of that batch, then we
    find a matching step in sequence 4 and finally come back to sequence 0,
    completing a cycle.

    Parameters
    ----------
    num_cycles : int
        Number of cycles that will be matched in one pass.
    batch_size : int
        Number of sequences in one batch.
    cycle_length : int
        Length of the cycles. If we are matching between 2 sequences
        (cycle_length=2), we get cycles that look like [0, 1, 0]. This means
        that we go from sequence 0 to sequence 1 and then back to sequence 0.
        A cycle length of 3 might look like [0, 1, 2, 0].

    Returns
    -------
    cycles : torch.Tensor
        Batch indices denoting cycles that will be used for calculating the alignment loss.

    """
    sorted_idxes = torch.arange(batch_size).unsqueeze(0).repeat([num_cycles, 1])
    sorted_idxes = sorted_idxes.view([batch_size, num_cycles])
    cycles = sorted_idxes[torch.randperm(len(sorted_idxes))].view(
        [num_cycles, batch_size]
    )
    cycles = cycles[:, :cycle_length]
    cycles = torch.cat([cycles, cycles[:, 0:1]], dim=1)

    return cycles


def compute_stochastic_alignment_loss(
    embs,
    steps,
    seq_lens,
    num_steps,
    batch_size,
    loss_type,
    similarity_type,
    num_cycles,
    cycle_length,
    temperature,
    variance_lambda,
    normalize_indices,
    real_lens,
):
    """Compute stochastic alignment loss.

    Parameters
    ----------
    embs : torch.Tensor
        Embeddings of shape (batch_size, num_steps, num_channels).
    steps : torch.Tensor
        Steps of shape (batch_size, num_steps).
    seq_lens : torch.Tensor
        Sequence lengths of shape (batch_size,).
    num_steps : int
        Number of steps in the sequence.
    batch_size : int
        Batch size.
    loss_type : str
        Type of loss to use. Currently only the regression variants
        ("regression_mse", "regression_mse_var") are supported.
    similarity_type : str
        Type of similarity to use. Can be either "l2" or "cosine".
    num_cycles : int
        Number of cycles to use for alignment.
    cycle_length : int
        Length of cycles to use for alignment.
    temperature : float
        Softmax temperature to use for alignment.
    variance_lambda : float
        Lambda to use for variance regularization.
    normalize_indices : bool
        Whether to normalize indices by sequence length.
    real_lens : torch.Tensor
        Real (unpadded) lengths of the sequences.

    Returns
    -------
    loss : torch.Tensor
        Alignment loss.

    """
    cycles = gen_cycles(num_cycles, batch_size, cycle_length).to(embs.device)
    logits, labels = _align(
        cycles=cycles,
        embs=embs,
        num_steps=num_steps,
        real_lens=real_lens,
        num_cycles=num_cycles,
        cycle_length=cycle_length,
        similarity_type=similarity_type,
        temperature=temperature,
        batch_size=batch_size,
    )

    if "regression" in loss_type:
        steps = steps[cycles[:, 0]]
        seq_lens = seq_lens[cycles[:, 0]]
        loss = regression_loss(
            logits,
            labels,
            num_steps,
            steps,
            seq_lens,
            loss_type,
            normalize_indices,
            variance_lambda,
        )
    else:
        raise ValueError(
            "Unidentified loss type %s. Currently supported loss "
            "types are: regression_mse, regression_mse_var." % loss_type
        )
    return loss


def compute_alignment_loss(
    embs,
    real_lens,
    steps=None,
    seq_lens=None,
    normalize_embeddings=False,
    loss_type="classification",
    similarity_type="l2",
    num_cycles=20,
    cycle_length=2,
    temperature=0.1,
    label_smoothing=0.1,
    variance_lambda=0.001,
    huber_delta=0.1,
    normalize_indices=True,
):
    """Compute alignment loss.

    Parameters
    ----------
    embs : torch.Tensor
        Sequence embeddings of shape (batch_size, num_steps, emb_dim).
    real_lens : torch.Tensor
        Real (unpadded) length of each sequence in the batch, of shape (batch_size,).
    steps : torch.Tensor, optional
        Step indices of shape (batch_size, num_steps), by default None.
    seq_lens : torch.Tensor, optional
        Length of each sequence in the batch, of shape (batch_size,), by default None.
    normalize_embeddings : bool, default False
        Whether to normalize the embeddings.
    loss_type : str, default "classification"
        Type of loss to use (only the regression variants are currently supported downstream).
    similarity_type : str, default "l2"
        Type of similarity to use ("l2" or "cosine").
    num_cycles : int, default 20
        Number of cycles to use.
    cycle_length : int, default 2
        Length of the cycles to use.
    temperature : float, default 0.1
        Temperature to use for the softmax.
    label_smoothing : float, default 0.1
        Label smoothing to use (currently unused).
    variance_lambda : float, default 0.001
        Variance lambda to use.
    huber_delta : float, default 0.1
        Huber delta to use (currently unused).
    normalize_indices : bool, default True
        Whether to normalize indices by sequence length.

    """
    # Get the number of timestamps in the sequence embeddings.
    num_steps = embs.shape[1]
    batch_size = embs.shape[0]

    # If steps has not been provided, assume sampling was done uniformly.
    if steps is None:
        steps = (
            torch.arange(0, num_steps)
            .unsqueeze(0)
            .repeat([batch_size, 1])
            .to(embs.device)
        )

    # If seq_lens has not been provided, assume it is equal to the size of the
    # time axis in the embeddings.
    if seq_lens is None:
        seq_lens = (
            torch.tensor(num_steps)
            .unsqueeze(0)
            .repeat([batch_size])
            .int()
            .to(embs.device)
        )

    # Check that steps is consistent with embs.
    assert num_steps == steps.shape[1]
    assert batch_size == steps.shape[0]

    if normalize_embeddings:
        embs = F.normalize(embs, dim=-1, p=2)

    loss = compute_stochastic_alignment_loss(
        embs=embs,
        steps=steps,
        seq_lens=seq_lens,
        num_steps=num_steps,
        batch_size=batch_size,
        loss_type=loss_type,
        similarity_type=similarity_type,
        num_cycles=num_cycles,
        cycle_length=cycle_length,
        temperature=temperature,
        # label_smoothing=label_smoothing,
        variance_lambda=variance_lambda,
        # huber_delta=huber_delta,
        normalize_indices=normalize_indices,
        real_lens=real_lens,
    )

    return loss


def regression_loss(
    logits,
    labels,
    num_steps,
    steps,
    seq_lens,
    loss_type,
    normalize_indices,
    variance_lambda,
):
    """Loss function based on regressing to the correct indices.

    In the paper, this is called Cycle-back Regression. There are 3 variants
    of this loss:
    i) regression_mse: MSE of the predicted indices and ground truth indices.
    ii) regression_mse_var: MSE of the predicted indices that takes into account
    the variance of the similarities. This is important when the rate at which
    sequences go through different phases changes a lot. The variance scaling
    allows dynamic weighting of the MSE loss based on the similarities.
    iii) regression_huber: Huber loss between the predicted indices and ground
    truth indices (not implemented in this version).

    Parameters
    ----------
    logits : torch.Tensor
        Pre-softmax similarity scores after cycling back to the starting sequence,
        of shape (batch_size, num_steps).
    labels : torch.Tensor
        One-hot labels containing the ground truth. The index where the cycle
        started is 1. Shape (batch_size, num_steps).
    num_steps : int
        Number of steps in the sequence embeddings.
    steps : torch.Tensor
        Step indices/frame indices of the embeddings, of shape (batch_size, num_steps).
    seq_lens : torch.Tensor
        Lengths of the sequences from which the sampling was done. This can
        provide additional temporal information to the alignment loss.
    loss_type : str
        The kind of regression loss function. Currently supported
        loss functions: regression_mse, regression_mse_var.
    normalize_indices : bool
        If True, normalizes indices by sequence lengths. Useful for ensuring
        numerical stability, as sequence indices can be large numbers.
    variance_lambda : float
        Weight of the variance of the similarity predictions while cycling back
        to the starting sequence.

    Returns
    -------
    loss : torch.Tensor
        A scalar loss calculated using a variant of regression.

    """
    # Just to be safe, we stop gradients from labels as we are generating labels.
    labels = labels.detach()
    steps = steps.detach()

    if normalize_indices:
        float_seq_lens = seq_lens.float()
        tile_seq_lens = (
            torch.tile(torch.unsqueeze(float_seq_lens, dim=1), [1, num_steps]) + 1e-7
        )
        steps = steps.float() / tile_seq_lens
    else:
        steps = steps.float()

    beta = F.softmax(logits, dim=1)
    true_time = torch.sum(steps * labels, dim=1)
    pred_time = torch.sum(steps * beta, dim=1)

    if loss_type in ["regression_mse", "regression_mse_var"]:
        if "var" in loss_type:
            # Variance-aware regression.
            pred_time_tiled = torch.tile(
                torch.unsqueeze(pred_time, dim=1), [1, num_steps]
            )

            pred_time_variance = torch.sum(
                ((steps - pred_time_tiled) ** 2) * beta, dim=1
            )

            # Using the log of the variance as it is numerically more stable.
            pred_time_log_var = torch.log(pred_time_variance + 1e-7)
            squared_error = (true_time - pred_time) ** 2
            return torch.mean(
                torch.exp(-pred_time_log_var) * squared_error
                + variance_lambda * pred_time_log_var
            )

        else:
            return torch.mean((true_time - pred_time) ** 2)
    else:
        raise ValueError(
            "Unsupported regression loss %s. Supported losses are: "
            "regression_mse, regression_mse_var." % loss_type
        )
class TCCLoss(nn.Module)
Temporal Cycle Consistency Loss.
def __init__(self, loss_type: str = "regression_mse_var", variance_lambda: float = 0.001, normalize_indices: bool = True, normalize_embeddings: bool = False, similarity_type: str = "l2", num_cycles: int = 20, cycle_length: int = 2, temperature: float = 0.1, label_smoothing: float = 0.1)
Initialize the loss.
def forward(self, predictions: torch.Tensor, mask: torch.Tensor) -> float
Forward pass. Expects predictions of shape (batch, embedding dim, #frames), which are transposed to (batch, #frames, embedding dim) internally, and a binary mask marking valid frames; returns the TCC alignment loss (or 0 when the batch contains a single sequence).
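To make the expected shapes concrete, here is a minimal usage sketch; the batch, channel, and frame sizes (and the 2-D mask layout) are illustrative assumptions, not the only layout the rest of dlc2action may use:

import torch

from dlc2action.loss.tcc import TCCLoss

loss_fn = TCCLoss(loss_type="regression_mse_var", num_cycles=20, cycle_length=2)

batch, emb_dim, num_frames = 8, 16, 100
predictions = torch.randn(batch, emb_dim, num_frames)  # (batch, channels, frames)
mask = torch.ones(batch, num_frames)                   # 1 for valid frames, 0 for padding
mask[0, 60:] = 0                                       # e.g. the first sequence ends at frame 60

loss = loss_fn(predictions, mask)                      # scalar tensor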
def gen_cycles(num_cycles, batch_size, cycle_length=2)
Generate cycles for alignment.
Generates a batch of indices to cycle over. For example, setting num_cycles=2, batch_size=5, cycle_length=3 might return something like this: cycles = [[0, 3, 4, 0], [1, 2, 0, 3]]. This means we have 2 cycles for which the loss will be calculated. The first cycle starts at sequence 0 of the batch, then we find a matching step in sequence 3 of that batch, then a matching step in sequence 4, and finally come back to sequence 0, completing the cycle.
Parameters
num_cycles : int
    Number of cycles that will be matched in one pass.
batch_size : int
    Number of sequences in one batch.
cycle_length : int
    Length of the cycles. If we are matching between 2 sequences
    (cycle_length=2), we get cycles that look like [0, 1, 0]. This means
    that we go from sequence 0 to sequence 1 and then back to sequence 0.
    A cycle length of 3 might look like [0, 1, 2, 0].
Returns
cycles : torch.Tensor
    Batch indices denoting cycles that will be used for calculating the alignment loss.
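A quick sketch of the returned tensor for the docstring's example settings; the indices are random, so only the shape and the closing index are deterministic:

import torch

from dlc2action.loss.tcc import gen_cycles

cycles = gen_cycles(num_cycles=2, batch_size=5, cycle_length=3)
assert cycles.shape == (2, 4)                    # cycle_length indices + the starting index again
assert torch.equal(cycles[:, 0], cycles[:, -1])  # every cycle closes on its start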
def compute_stochastic_alignment_loss(embs, steps, seq_lens, num_steps, batch_size, loss_type, similarity_type, num_cycles, cycle_length, temperature, variance_lambda, normalize_indices, real_lens)
Compute stochastic alignment loss.
Parameters
embs : torch.Tensor
    Embeddings of shape (batch_size, num_steps, num_channels).
steps : torch.Tensor
    Steps of shape (batch_size, num_steps).
seq_lens : torch.Tensor
    Sequence lengths of shape (batch_size,).
num_steps : int
    Number of steps in the sequence.
batch_size : int
    Batch size.
loss_type : str
    Type of loss to use. Currently only the regression variants
    ("regression_mse", "regression_mse_var") are supported.
similarity_type : str
    Type of similarity to use. Can be either "l2" or "cosine".
num_cycles : int
    Number of cycles to use for alignment.
cycle_length : int
    Length of cycles to use for alignment.
temperature : float
    Softmax temperature to use for alignment.
variance_lambda : float
    Lambda to use for variance regularization.
normalize_indices : bool
    Whether to normalize indices by sequence length.
real_lens : torch.Tensor
    Real (unpadded) lengths of the sequences.
Returns
loss : torch.Tensor
    Alignment loss.
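For reference, the logits come from the repeated soft nearest-neighbor lookups implemented in _align_single_cycle above. As a sketch of the code's arithmetic for the "l2" similarity, with query $u$, candidate frames $v_k$, $d$ channels, and temperature $\tau$, each hop computes

\alpha_k = -\frac{\lVert u - v_k \rVert^2}{d\,\tau}, \qquad \beta_k = \frac{e^{\alpha_k}}{\sum_j e^{\alpha_j}}, \qquad \tilde{u} = \sum_k \beta_k v_k

The weighted average $\tilde{u}$ becomes the query for the next hop, and the $\alpha_k$ of the final hop are returned as logits, paired with a one-hot label at the randomly chosen starting frame.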
def compute_alignment_loss(embs, real_lens, steps=None, seq_lens=None, normalize_embeddings=False, loss_type="classification", similarity_type="l2", num_cycles=20, cycle_length=2, temperature=0.1, label_smoothing=0.1, variance_lambda=0.001, huber_delta=0.1, normalize_indices=True)
Compute alignment loss.
Parameters
embs : torch.Tensor
    Sequence embeddings of shape (batch_size, num_steps, emb_dim).
real_lens : torch.Tensor
    Real (unpadded) length of each sequence in the batch, of shape (batch_size,).
steps : torch.Tensor, optional
    Step indices of shape (batch_size, num_steps); if None, uniform sampling is assumed.
seq_lens : torch.Tensor, optional
    Length of each sequence in the batch, of shape (batch_size,); if None, assumed equal to num_steps.
normalize_embeddings : bool, default False
    Whether to normalize the embeddings.
loss_type : str, default "classification"
    Type of loss to use (only the regression variants are currently supported downstream).
similarity_type : str, default "l2"
    Type of similarity to use ("l2" or "cosine").
num_cycles : int, default 20
    Number of cycles to use.
cycle_length : int, default 2
    Length of the cycles to use.
temperature : float, default 0.1
    Temperature to use for the softmax.
label_smoothing : float, default 0.1
    Label smoothing to use (currently unused).
variance_lambda : float, default 0.001
    Variance lambda to use.
huber_delta : float, default 0.1
    Huber delta to use (currently unused).
normalize_indices : bool, default True
    Whether to normalize indices by sequence length.

Returns

loss : torch.Tensor
    Alignment loss.
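A minimal direct-call sketch with illustrative sizes; note that the default loss_type="classification" is rejected by compute_stochastic_alignment_loss, so a regression variant is passed explicitly:

import torch

from dlc2action.loss.tcc import compute_alignment_loss

embs = torch.randn(4, 50, 8)                       # (batch_size, num_steps, emb_dim)
real_lens = torch.full((4,), 50, dtype=torch.long) # all sequences fully valid
loss = compute_alignment_loss(
    embs,
    real_lens,
    loss_type="regression_mse_var",
    num_cycles=10,
    cycle_length=2,
)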
def regression_loss(logits, labels, num_steps, steps, seq_lens, loss_type, normalize_indices, variance_lambda)
Loss function based on regressing to the correct indices.
In the paper, this is called Cycle-back Regression. There are 3 variants of this loss:

i) regression_mse: MSE of the predicted indices and ground truth indices.
ii) regression_mse_var: MSE of the predicted indices that takes into account the variance of the similarities. This is important when the rate at which sequences go through different phases changes a lot. The variance scaling allows dynamic weighting of the MSE loss based on the similarities.
iii) regression_huber: Huber loss between the predicted indices and ground truth indices (not implemented in this version).
Parameters
logits : torch.Tensor
    Pre-softmax similarity scores after cycling back to the starting sequence,
    of shape (batch_size, num_steps).
labels : torch.Tensor
    One-hot labels containing the ground truth. The index where the cycle
    started is 1. Shape (batch_size, num_steps).
num_steps : int
    Number of steps in the sequence embeddings.
steps : torch.Tensor
    Step indices/frame indices of the embeddings, of shape (batch_size, num_steps).
seq_lens : torch.Tensor
    Lengths of the sequences from which the sampling was done. This can
    provide additional temporal information to the alignment loss.
loss_type : str
    The kind of regression loss function. Currently supported loss functions:
    regression_mse, regression_mse_var.
normalize_indices : bool
    If True, normalizes indices by sequence lengths. Useful for ensuring
    numerical stability, as sequence indices can be large numbers.
variance_lambda : float
    Weight of the variance of the similarity predictions while cycling back
    to the starting sequence.
Returns
loss : torch.Tensor
    A scalar loss calculated using a variant of regression.
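Written out, with $\beta$ the softmax of the cycle-back logits, $s_k$ the (optionally normalized) step indices, $t = \sum_k s_k y_k$ the true start time, and $\hat{t} = \sum_k \beta_k s_k$ the predicted time, regression_mse minimizes the mean of $(t - \hat{t})^2$ over cycles, while regression_mse_var minimizes the mean of

\frac{(t - \hat{t})^2}{\sigma^2} + \lambda \log \sigma^2, \qquad \sigma^2 = \sum_k \beta_k (s_k - \hat{t})^2

with $\lambda$ = variance_lambda; the code works with $\log(\sigma^2 + 10^{-7})$ and $e^{-\log\sigma^2}$ for numerical stability.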