dlc2action.loss.tcc
TCC loss
Adapted from https://github.com/June01/tcc_Temporal_Cycle_Consistency_Loss.pytorch
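A minimal usage sketch, assuming the module is importable as dlc2action.loss.tcc and that, per the forward signature below, predictions is a [batch, channels, frames] tensor and mask is a binary tensor whose last dimension marks valid frames (the concrete sizes here are made up for illustration):

import torch
from dlc2action.loss.tcc import _TCCLoss

# Hypothetical shapes: 8 sequences, 16 feature channels, 64 frames.
predictions = torch.randn(8, 16, 64)   # [batch, channels, frames]
mask = torch.ones(8, 1, 64)            # 1 for valid frames, 0 for padding
mask[:, :, 50:] = 0                    # e.g. the last 14 frames are padding

loss_fn = _TCCLoss(loss_type="regression_mse_var", num_cycles=20, cycle_length=2)
loss = loss_fn(predictions, mask)      # scalar cycle-consistency alignment loss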
#
# Copyright 2020-2022 by A. Mathis Group and contributors. All rights reserved.
#
# This project and all its files are licensed under GNU AGPLv3 or later version. A copy is included in dlc2action/LICENSE.AGPL.
#
#
# Adapted from Temporal Cycle-Consistency Learning by June01
# Adapted from https://github.com/June01/tcc_Temporal_Cycle_Consistency_Loss.pytorch
# Licensed under Apache License Version 2.0, January 2004
#
"""TCC loss

Adapted from https://github.com/June01/tcc_Temporal_Cycle_Consistency_Loss.pytorch

"""

import torch
from torch.nn import functional as F
from torch import nn


class _TCCLoss(nn.Module):
    def __init__(
        self,
        loss_type: str = "regression_mse_var",
        variance_lambda: float = 0.001,
        normalize_indices: bool = True,
        normalize_embeddings: bool = False,
        similarity_type: str = "l2",
        num_cycles: int = 20,
        cycle_length: int = 2,
        temperature: float = 0.1,
        label_smoothing: float = 0.1,
    ):
        super().__init__()
        self.loss_type = loss_type
        self.variance_lambda = variance_lambda
        self.normalize_indices = normalize_indices
        self.normalize_embeddings = normalize_embeddings
        self.similarity_type = similarity_type
        self.num_cycles = num_cycles
        self.cycle_length = cycle_length
        self.temperature = temperature
        self.label_smoothing = label_smoothing

    def forward(self, predictions: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        real_lens = mask.sum(-1).squeeze().cpu()
        if len(real_lens.shape) == 0:
            return torch.tensor(0)
        return compute_alignment_loss(
            predictions.transpose(-1, -2),
            normalize_embeddings=self.normalize_embeddings,
            normalize_indices=self.normalize_indices,
            loss_type=self.loss_type,
            similarity_type=self.similarity_type,
            num_cycles=self.num_cycles,
            cycle_length=self.cycle_length,
            temperature=self.temperature,
            label_smoothing=self.label_smoothing,
            variance_lambda=self.variance_lambda,
            real_lens=real_lens,
        )


def _align_single_cycle(
    cycle, embs, cycle_length, num_steps, real_len, similarity_type, temperature
):
    # Choose a random frame within the real (unpadded) length.
    n_idx = (torch.rand(1) * real_len).long()[0]

    # Create one-hot labels marking the chosen frame.
    onehot_labels = torch.eye(num_steps)[n_idx]

    # Choose query feats for the first frame.
    query_feats = embs[cycle[0], n_idx : n_idx + 1]
    num_channels = query_feats.size(-1)
    for c in range(1, cycle_length + 1):
        candidate_feats = embs[cycle[c]]
        if similarity_type == "l2":
            mean_squared_distance = torch.sum(
                (query_feats.repeat([num_steps, 1]) - candidate_feats) ** 2, dim=1
            )
            similarity = -mean_squared_distance
        elif similarity_type == "cosine":
            similarity = torch.squeeze(
                torch.matmul(candidate_feats, query_feats.transpose(0, 1))
            )
        else:
            raise ValueError("similarity_type can either be l2 or cosine.")

        similarity /= float(num_channels)
        similarity /= temperature

        beta = F.softmax(similarity, dim=0).unsqueeze(1).repeat([1, num_channels])
        query_feats = torch.sum(beta * candidate_feats, dim=0, keepdim=True)

    return similarity.unsqueeze(0), onehot_labels.unsqueeze(0)


def _align(
    cycles,
    embs,
    num_steps,
    real_lens,
    num_cycles,
    cycle_length,
    similarity_type,
    temperature,
    batch_size,
):
    """Align by finding cycles in embs."""
    logits_list = []
    labels_list = []
    for i in range(num_cycles):
        if len(real_lens) == batch_size:
            real_len = int(real_lens[cycles[i][0]])
        else:
            real_len = int(real_lens[cycles[i][0] // batch_size])
        logits, labels = _align_single_cycle(
            cycles[i],
            embs,
            cycle_length,
            num_steps,
            real_len,
            similarity_type,
            temperature,
        )
        logits_list.append(logits)
        labels_list.append(labels)

    logits = torch.cat(logits_list, dim=0)
    labels = torch.cat(labels_list, dim=0).to(embs.device)

    return logits, labels


def gen_cycles(num_cycles, batch_size, cycle_length=2):
    """Generate cycles for alignment.

    Generates a batch of indices to cycle over. For example, setting num_cycles=2,
    batch_size=5, cycle_length=3 might return something like this:
    cycles = [[0, 3, 4, 0], [1, 2, 0, 3]]. This means we have 2 cycles for which
    the loss will be calculated. The first cycle starts at sequence 0 of the
    batch, then we find a matching step in sequence 3 of that batch, then we
    find a matching step in sequence 4 and finally come back to sequence 0,
    completing a cycle.

    Args:
        num_cycles: Integer, number of cycles that will be matched in one pass.
        batch_size: Integer, number of sequences in one batch.
        cycle_length: Integer, length of the cycles. If we are matching between
            2 sequences (cycle_length=2), we get cycles that look like [0, 1, 0].
            This means that we go from sequence 0 to sequence 1 and then back to
            sequence 0. A cycle length of 3 might look like [0, 1, 2, 0].

    Returns:
        cycles: Tensor, batch indices denoting cycles that will be used for
            calculating the alignment loss.
    """
    sorted_idxes = torch.arange(batch_size).unsqueeze(0).repeat([num_cycles, 1])
    sorted_idxes = sorted_idxes.view([batch_size, num_cycles])
    cycles = sorted_idxes[torch.randperm(len(sorted_idxes))].view(
        [num_cycles, batch_size]
    )
    cycles = cycles[:, :cycle_length]
    cycles = torch.cat([cycles, cycles[:, 0:1]], dim=1)

    return cycles


def compute_stochastic_alignment_loss(
    embs,
    steps,
    seq_lens,
    num_steps,
    batch_size,
    loss_type,
    similarity_type,
    num_cycles,
    cycle_length,
    temperature,
    label_smoothing,
    variance_lambda,
    huber_delta,
    normalize_indices,
    real_lens,
):
    cycles = gen_cycles(num_cycles, batch_size, cycle_length).to(embs.device)
    logits, labels = _align(
        cycles=cycles,
        embs=embs,
        num_steps=num_steps,
        real_lens=real_lens,
        num_cycles=num_cycles,
        cycle_length=cycle_length,
        similarity_type=similarity_type,
        temperature=temperature,
        batch_size=batch_size,
    )

    if "regression" in loss_type:
        steps = steps[cycles[:, 0]]
        seq_lens = seq_lens[cycles[:, 0]]
        loss = regression_loss(
            logits,
            labels,
            num_steps,
            steps,
            seq_lens,
            loss_type,
            normalize_indices,
            variance_lambda,
        )
    else:
        raise ValueError(
            "Unidentified loss type %s. Currently supported loss "
            "types are: regression_mse, regression_mse_var." % loss_type
        )
    return loss


def compute_alignment_loss(
    embs,
    real_lens,
    steps=None,
    seq_lens=None,
    normalize_embeddings=False,
    loss_type="classification",
    similarity_type="l2",
    num_cycles=20,
    cycle_length=2,
    temperature=0.1,
    label_smoothing=0.1,
    variance_lambda=0.001,
    huber_delta=0.1,
    normalize_indices=True,
):
    # Get the number of timesteps in the sequence embeddings.
    num_steps = embs.shape[1]
    batch_size = embs.shape[0]

    # If steps has not been provided, assume sampling was done uniformly.
    if steps is None:
        steps = (
            torch.arange(0, num_steps)
            .unsqueeze(0)
            .repeat([batch_size, 1])
            .to(embs.device)
        )

    # If seq_lens has not been provided, assume it is equal to the size of the
    # time axis in the embeddings.
    if seq_lens is None:
        seq_lens = (
            torch.tensor(num_steps)
            .unsqueeze(0)
            .repeat([batch_size])
            .int()
            .to(embs.device)
        )

    # Check that batch_size and num_steps are consistent with the embeddings.
    assert num_steps == steps.shape[1]
    assert batch_size == steps.shape[0]

    if normalize_embeddings:
        embs = F.normalize(embs, dim=-1, p=2)

    loss = compute_stochastic_alignment_loss(
        embs=embs,
        steps=steps,
        seq_lens=seq_lens,
        num_steps=num_steps,
        batch_size=batch_size,
        loss_type=loss_type,
        similarity_type=similarity_type,
        num_cycles=num_cycles,
        cycle_length=cycle_length,
        temperature=temperature,
        label_smoothing=label_smoothing,
        variance_lambda=variance_lambda,
        huber_delta=huber_delta,
        normalize_indices=normalize_indices,
        real_lens=real_lens,
    )

    return loss


def regression_loss(
    logits,
    labels,
    num_steps,
    steps,
    seq_lens,
    loss_type,
    normalize_indices,
    variance_lambda,
):
    """Loss function based on regressing to the correct indices.

    In the paper, this is called Cycle-back Regression. There are 3 variants
    of this loss:
    i) regression_mse: MSE of the predicted indices and ground truth indices.
    ii) regression_mse_var: MSE of the predicted indices that takes into account
        the variance of the similarities. This is important when the rate at
        which sequences go through different phases changes a lot. The variance
        scaling allows dynamic weighting of the MSE loss based on the
        similarities.
    iii) regression_huber: Huber loss between the predicted indices and ground
        truth indices.

    Args:
        logits: Tensor, pre-softmax similarity scores after cycling back to the
            starting sequence.
        labels: Tensor, one-hot labels containing the ground truth. The index
            where the cycle started is 1.
        num_steps: Integer, number of steps in the sequence embeddings.
        steps: Tensor, step/frame indices of the embeddings of shape [N, T],
            where N is the batch size and T is the number of timesteps.
        seq_lens: Tensor, lengths of the sequences from which the sampling was
            done. This can provide additional temporal information to the
            alignment loss.
        loss_type: String, specifies the kind of regression loss function.
            Currently supported loss functions: regression_mse,
            regression_mse_var, regression_huber.
        normalize_indices: Boolean, if True, normalizes indices by sequence
            lengths. Useful for ensuring numerical instabilities don't arise as
            sequence indices can be large numbers.
        variance_lambda: Float, weight of the variance of the similarity
            predictions while cycling back. If this is high, low-variance
            similarities are preferred by the loss, while making this term low
            results in high variance of the similarities (more uniform/random
            matching).

    Returns:
        loss: Tensor, a scalar loss calculated using a variant of regression.
    """
    # Just to be safe, we stop gradients from labels as we are generating labels.
    labels = labels.detach()
    steps = steps.detach()

    if normalize_indices:
        float_seq_lens = seq_lens.float()
        tile_seq_lens = (
            torch.tile(torch.unsqueeze(float_seq_lens, dim=1), [1, num_steps]) + 1e-7
        )
        steps = steps.float() / tile_seq_lens
    else:
        steps = steps.float()

    beta = F.softmax(logits, dim=1)
    true_time = torch.sum(steps * labels, dim=1)
    pred_time = torch.sum(steps * beta, dim=1)

    if loss_type in ["regression_mse", "regression_mse_var"]:
        if "var" in loss_type:
            # Variance-aware regression.
            pred_time_tiled = torch.tile(
                torch.unsqueeze(pred_time, dim=1), [1, num_steps]
            )

            pred_time_variance = torch.sum(
                ((steps - pred_time_tiled) ** 2) * beta, dim=1
            )

            # Use the log of the variance as it is numerically more stable.
            pred_time_log_var = torch.log(pred_time_variance + 1e-7)
            squared_error = (true_time - pred_time) ** 2
            return torch.mean(
                torch.exp(-pred_time_log_var) * squared_error
                + variance_lambda * pred_time_log_var
            )

        else:
            return torch.mean((true_time - pred_time) ** 2)
    else:
        raise ValueError(
            "Unsupported regression loss %s. Supported losses are: "
            "regression_mse, regression_mse_var." % loss_type
        )
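As a quick sanity check of gen_cycles, the following sketch mirrors the docstring example; the actual indices are random, but the shape and the repeated first/last column follow directly from the code above:

import torch
from dlc2action.loss.tcc import gen_cycles

cycles = gen_cycles(num_cycles=2, batch_size=5, cycle_length=3)
print(cycles.shape)                            # torch.Size([2, 4]): one row per cycle, cycle_length + 1 indices
print((cycles[:, 0] == cycles[:, -1]).all())   # tensor(True): each cycle returns to its starting sequence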
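For reference, the regression_mse_var branch of regression_loss can be written compactly as below; this is a restatement of the code (omitting the 1e-7 stabilizers), not an excerpt from the paper:

\hat{t}_i = \sum_k \beta_{ik}\, s_k, \qquad
\sigma_i^2 = \sum_k \beta_{ik}\,(s_k - \hat{t}_i)^2, \qquad
\mathcal{L} = \frac{1}{C}\sum_{i=1}^{C}\left[\frac{(t_i - \hat{t}_i)^2}{\sigma_i^2} + \lambda \log \sigma_i^2\right]

where beta_i is the softmax of the i-th cycle's logits over time steps, s_k are the (optionally length-normalized) frame indices, t_i is the index the i-th cycle started from, C is num_cycles, and lambda is variance_lambda.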