dlc2action.loss.tcc

TCC loss

Adapted from https://github.com/June01/tcc_Temporal_Cycle_Consistency_Loss.pytorch
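A minimal usage sketch of the loss class defined below. The tensor shapes are assumptions for illustration: the forward pass transposes the last two prediction dimensions and derives per-sample lengths from the mask, so any mask whose sum over the last dimension gives the number of valid frames works the same way.

    import torch
    from dlc2action.loss.tcc import _TCCLoss

    loss_fn = _TCCLoss(loss_type="regression_mse_var", variance_lambda=0.001)
    predictions = torch.randn(4, 16, 100)  # (batch, feature channels, frames), assumed shape
    mask = torch.ones(4, 1, 100)           # 1 for valid frames, 0 for padding
    loss = loss_fn(predictions, mask)      # scalar alignment loss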

  1#
  2# Copyright 2020-2022 by A. Mathis Group and contributors. All rights reserved.
  3#
  4# This project and all its files are licensed under GNU AGPLv3 or later version. A copy is included in dlc2action/LICENSE.AGPL.
  5#
  6#
  7# Adapted from Temporal Cycle-Consistency Learning by June01
  8# Adapted from https://github.com/June01/tcc_Temporal_Cycle_Consistency_Loss.pytorch
  9# Licensed under Apache License Version 2.0, January 2004
 10#
 11""" TCC loss
 12
 13Adapted from https://github.com/June01/tcc_Temporal_Cycle_Consistency_Loss.pytorch
 14
 15"""
 16
 17import torch
 18from torch.nn import functional as F
 19from torch import nn
 20
 21
 22class _TCCLoss(nn.Module):
 23    def __init__(
 24        self,
 25        loss_type: str = "regression_mse_var",
 26        variance_lambda: float = 0.001,
 27        normalize_indices: bool = True,
 28        normalize_embeddings: bool = False,
 29        similarity_type: str = "l2",
 30        num_cycles: int = 20,
 31        cycle_length: int = 2,
 32        temperature: float = 0.1,
 33        label_smoothing: float = 0.1,
 34    ):
 35        super().__init__()
 36        self.loss_type = loss_type
 37        self.variance_lambda = variance_lambda
 38        self.normalize_indices = normalize_indices
 39        self.normalize_embeddings = normalize_embeddings
 40        self.similarity_type = similarity_type
 41        self.num_cycles = num_cycles
 42        self.cycle_length = cycle_length
 43        self.temperature = temperature
 44        self.label_smoothing = label_smoothing
 45
 46    def forward(self, predictions: torch.Tensor, mask: torch.Tensor) -> float:
 47        real_lens = mask.sum(-1).squeeze().cpu()
 48        if len(real_lens.shape) == 0:
 49            return torch.tensor(0)
 50        return compute_alignment_loss(
 51            predictions.transpose(-1, -2),
 52            normalize_embeddings=self.normalize_embeddings,
 53            normalize_indices=self.normalize_indices,
 54            loss_type=self.loss_type,
 55            similarity_type=self.similarity_type,
 56            num_cycles=self.num_cycles,
 57            cycle_length=self.cycle_length,
 58            temperature=self.temperature,
 59            label_smoothing=self.label_smoothing,
 60            variance_lambda=self.variance_lambda,
 61            real_lens=real_lens,
 62        )
 63
 64
 65def _align_single_cycle(
 66    cycle, embs, cycle_length, num_steps, real_len, similarity_type, temperature
 67):
 68    # choose a random frame
 69    n_idx = (torch.rand(1) * real_len).long()[0]
 70    # n_idx = torch.tensor(8).long()
 71
 72    # Create labels
 73    onehot_labels = torch.eye(num_steps)[n_idx]
 74
 75    # Choose query feats for first frame.
 76    query_feats = embs[cycle[0], n_idx : n_idx + 1]
 77    num_channels = query_feats.size(-1)
 78    for c in range(1, cycle_length + 1):
 79        candidate_feats = embs[cycle[c]]
 80        if similarity_type == "l2":
 81            mean_squared_distance = torch.sum(
 82                (query_feats.repeat([num_steps, 1]) - candidate_feats) ** 2, dim=1
 83            )
 84            similarity = -mean_squared_distance
 85        elif similarity_type == "cosine":
 86            similarity = torch.squeeze(
 87                torch.matmul(candidate_feats, query_feats.transpose(0, 1))
 88            )
 89        else:
 90            raise ValueError("similarity_type can either be l2 or cosine.")
 91
 92        similarity /= float(num_channels)
 93        similarity /= temperature
 94
 95        beta = F.softmax(similarity, dim=0).unsqueeze(1).repeat([1, num_channels])
 96        query_feats = torch.sum(beta * candidate_feats, dim=0, keepdim=True)
 97
 98    return similarity.unsqueeze(0), onehot_labels.unsqueeze(0)
 99
100
101def _align(
102    cycles,
103    embs,
104    num_steps,
105    real_lens,
106    num_cycles,
107    cycle_length,
108    similarity_type,
109    temperature,
110    batch_size,
111):
112    """Align by finding cycles in embs."""
113    logits_list = []
114    labels_list = []
115    for i in range(num_cycles):
116        if len(real_lens) == batch_size:
117            real_len = int(real_lens[cycles[i][0]])
118        else:
119            real_len = int(real_lens[cycles[i][0] // batch_size])
120        logits, labels = _align_single_cycle(
121            cycles[i],
122            embs,
123            cycle_length,
124            num_steps,
125            real_len,
126            similarity_type,
127            temperature,
128        )
129        logits_list.append(logits)
130        labels_list.append(labels)
131
132    logits = torch.cat(logits_list, dim=0)
133    labels = torch.cat(labels_list, dim=0).to(embs.device)
134
135    return logits, labels
136
137
138def gen_cycles(num_cycles, batch_size, cycle_length=2):
139    """Generates cycles for alignment.
140    Generates a batch of indices to cycle over. For example setting num_cycles=2,
141    batch_size=5, cycle_length=3 might return something like this:
142    cycles = [[0, 3, 4, 0], [1, 2, 0, 3]]. This means we have 2 cycles for which
143    the loss will be calculated. The first cycle starts at sequence 0 of the
144    batch, then we find a matching step in sequence 3 of that batch, then we
145    find matching step in sequence 4 and finally come back to sequence 0,
146    completing a cycle.
147    Args:
148    num_cycles: Integer, Number of cycles that will be matched in one pass.
149    batch_size: Integer, Number of sequences in one batch.
150    cycle_length: Integer, Length of the cycles. If we are matching between
151      2 sequences (cycle_length=2), we get cycles that look like [0,1,0].
152      This means that we go from sequence 0 to sequence 1 then back to sequence
153      0. A cycle length of 3 might look like [0, 1, 2, 0].
154    Returns:
155    cycles: Tensor, Batch indices denoting cycles that will be used for
156      calculating the alignment loss.
157    """
158    sorted_idxes = torch.arange(batch_size).unsqueeze(0).repeat([num_cycles, 1])
159    sorted_idxes = sorted_idxes.view([batch_size, num_cycles])
160    cycles = sorted_idxes[torch.randperm(len(sorted_idxes))].view(
161        [num_cycles, batch_size]
162    )
163    cycles = cycles[:, :cycle_length]
164    cycles = torch.cat([cycles, cycles[:, 0:1]], dim=1)
165
166    return cycles
167
168
169def compute_stochastic_alignment_loss(
170    embs,
171    steps,
172    seq_lens,
173    num_steps,
174    batch_size,
175    loss_type,
176    similarity_type,
177    num_cycles,
178    cycle_length,
179    temperature,
180    label_smoothing,
181    variance_lambda,
182    huber_delta,
183    normalize_indices,
184    real_lens,
185):
186
187    cycles = gen_cycles(num_cycles, batch_size, cycle_length).to(embs.device)
188    logits, labels = _align(
189        cycles=cycles,
190        embs=embs,
191        num_steps=num_steps,
192        real_lens=real_lens,
193        num_cycles=num_cycles,
194        cycle_length=cycle_length,
195        similarity_type=similarity_type,
196        temperature=temperature,
197        batch_size=batch_size,
198    )
199
200    if "regression" in loss_type:
201        steps = steps[cycles[:, 0]]
202        seq_lens = seq_lens[cycles[:, 0]]
203        loss = regression_loss(
204            logits,
205            labels,
206            num_steps,
207            steps,
208            seq_lens,
209            loss_type,
210            normalize_indices,
211            variance_lambda,
212        )
213    else:
214        raise ValueError(
215            "Unidentified loss type %s. Currently supported loss "
216            "types are: regression_mse, regression_mse_var, " % loss_type
217        )
218    return loss
219
220
221def compute_alignment_loss(
222    embs,
223    real_lens,
224    steps=None,
225    seq_lens=None,
226    normalize_embeddings=False,
227    loss_type="classification",
228    similarity_type="l2",
229    num_cycles=20,
230    cycle_length=2,
231    temperature=0.1,
232    label_smoothing=0.1,
233    variance_lambda=0.001,
234    huber_delta=0.1,
235    normalize_indices=True,
236):
237
239    # Get the number of timesteps in the sequence embeddings.
239    num_steps = embs.shape[1]
240    batch_size = embs.shape[0]
241
243    # If steps has not been provided, assume sampling was done uniformly.
243    if steps is None:
244        steps = (
245            torch.arange(0, num_steps)
246            .unsqueeze(0)
247            .repeat([batch_size, 1])
248            .to(embs.device)
249        )
250
251    # If seq_lens has not been provided, assume it is equal to the size of the
252    # time axis in the embeddings.
253    if seq_lens is None:
254        seq_lens = (
255            torch.tensor(num_steps)
256            .unsqueeze(0)
257            .repeat([batch_size])
258            .int()
259            .to(embs.device)
260        )
261
262    # Check that batch_size and num_steps are consistent with steps.
263    assert num_steps == steps.shape[1]
264    assert batch_size == steps.shape[0]
265
266    if normalize_embeddings:
267        embs = F.normalize(embs, dim=-1, p=2)
268
269    loss = compute_stochastic_alignment_loss(
270        embs=embs,
271        steps=steps,
272        seq_lens=seq_lens,
273        num_steps=num_steps,
274        batch_size=batch_size,
275        loss_type=loss_type,
276        similarity_type=similarity_type,
277        num_cycles=num_cycles,
278        cycle_length=cycle_length,
279        temperature=temperature,
280        label_smoothing=label_smoothing,
281        variance_lambda=variance_lambda,
282        huber_delta=huber_delta,
283        normalize_indices=normalize_indices,
284        real_lens=real_lens,
285    )
286
287    return loss
288
289
290def regression_loss(
291    logits,
292    labels,
293    num_steps,
294    steps,
295    seq_lens,
296    loss_type,
297    normalize_indices,
298    variance_lambda,
299):
300    """Loss function based on regressing to the correct indices.
301    In the paper, this is called Cycle-back Regression. There are 3 variants
302    of this loss:
303    i) regression_mse: MSE of the predicted indices and ground truth indices.
304    ii) regression_mse_var: MSE of the predicted indices that takes into account
305    the variance of the similarities. This is important when the rate at which
306    sequences go through different phases changes a lot. The variance scaling
307    allows dynamic weighting of the MSE loss based on the similarities.
308    iii) regression_huber: Huber loss between the predicted indices and ground
309    truth indices.
310    Args:
311      logits: Tensor, Pre-softmax similarity scores after cycling back to the
312        starting sequence.
313      labels: Tensor, One hot labels containing the ground truth. The index where
314        the cycle started is 1.
315      num_steps: Integer, Number of steps in the sequence embeddings.
316      steps: Tensor, step indices/frame indices of the embeddings of the shape
317        [N, T] where N is the batch size, T is the number of the timesteps.
318      seq_lens: Tensor, Lengths of the sequences from which the sampling was done.
319        This can provide additional temporal information to the alignment loss.
320      loss_type: String, This specifies the kind of regression loss function.
321        Currently supported loss functions: regression_mse, regression_mse_var,
322        regression_huber.
323      normalize_indices: Boolean, If True, normalizes indices by sequence lengths.
324        Useful for ensuring numerical instabilities don't arise as sequence
325        indices can be large numbers.
326      variance_lambda: Float, Weight of the variance of the similarity
327        predictions while cycling back. If this is high then the low variance
328        similarities are preferred by the loss while making this term low results
329        in high variance of the similarities (more uniform/random matching).
330    Returns:
331       loss: Tensor, A scalar loss calculated using a variant of regression.
332    """
333    # Just to be safe, we stop gradients from labels as we are generating labels.
334    labels = labels.detach()
335    steps = steps.detach()
336
337    if normalize_indices:
338        float_seq_lens = seq_lens.float()
339        tile_seq_lens = (
340            torch.tile(torch.unsqueeze(float_seq_lens, dim=1), [1, num_steps]) + 1e-7
341        )
342        steps = steps.float() / tile_seq_lens
343    else:
344        steps = steps.float()
345
346    beta = F.softmax(logits, dim=1)
347    true_time = torch.sum(steps * labels, dim=1)
348    pred_time = torch.sum(steps * beta, dim=1)
349
350    if loss_type in ["regression_mse", "regression_mse_var"]:
351        if "var" in loss_type:
352            # Variance aware regression.
353            pred_time_tiled = torch.tile(
354                torch.unsqueeze(pred_time, dim=1), [1, num_steps]
355            )
356
357            pred_time_variance = torch.sum(
358                ((steps - pred_time_tiled) ** 2) * beta, dim=1
359            )
360
361            # Using log of variance as it is numerically stabler.
362            pred_time_log_var = torch.log(pred_time_variance + 1e-7)
363            squared_error = (true_time - pred_time) ** 2
364            return torch.mean(
365                torch.exp(-pred_time_log_var) * squared_error
366                + variance_lambda * pred_time_log_var
367            )
368
369        else:
370            return torch.mean((true_time - pred_time) ** 2)
371    else:
372        raise ValueError(
373            "Unsupported regression loss %s. Supported losses are: "
374            "regression_mse, regresstion_mse_var." % loss_type
375        )
def gen_cycles(num_cycles, batch_size, cycle_length=2)

Generates cycles for alignment.

Generates a batch of indices to cycle over. For example, setting num_cycles=2, batch_size=5, cycle_length=3 might return something like this: cycles = [[0, 3, 4, 0], [1, 2, 0, 3]]. This means we have 2 cycles for which the loss will be calculated. The first cycle starts at sequence 0 of the batch, then we find a matching step in sequence 3 of that batch, then we find a matching step in sequence 4 and finally come back to sequence 0, completing a cycle.

Args:
    num_cycles: Integer, number of cycles that will be matched in one pass.
    batch_size: Integer, number of sequences in one batch.
    cycle_length: Integer, length of the cycles. If we are matching between 2 sequences (cycle_length=2), we get cycles that look like [0, 1, 0]. This means that we go from sequence 0 to sequence 1 and then back to sequence 0. A cycle length of 3 might look like [0, 1, 2, 0].

Returns:
    cycles: Tensor, batch indices denoting cycles that will be used for calculating the alignment loss.
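A small sketch of a call; the returned indices are random, only the shape is deterministic:

    import torch
    from dlc2action.loss.tcc import gen_cycles

    cycles = gen_cycles(num_cycles=4, batch_size=6, cycle_length=2)
    print(cycles.shape)  # torch.Size([4, 3]); each row ends at the sequence it started from
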
def compute_stochastic_alignment_loss(embs, steps, seq_lens, num_steps, batch_size, loss_type, similarity_type, num_cycles, cycle_length, temperature, label_smoothing, variance_lambda, huber_delta, normalize_indices, real_lens)
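This helper has no docstring: as the source above shows, it samples num_cycles cycles with gen_cycles, computes cycle-back logits and one-hot labels with _align, and reduces them with regression_loss (only the "regression*" loss types are supported; others raise a ValueError). A minimal direct call, with argument values chosen purely for illustration:

    import torch
    from dlc2action.loss.tcc import compute_stochastic_alignment_loss

    embs = torch.randn(4, 50, 32)                       # (batch, timesteps, channels), assumed shape
    steps = torch.arange(50).unsqueeze(0).repeat(4, 1)  # frame indices per sequence
    loss = compute_stochastic_alignment_loss(
        embs=embs, steps=steps, seq_lens=torch.full((4,), 50), num_steps=50, batch_size=4,
        loss_type="regression_mse_var", similarity_type="l2", num_cycles=10,
        cycle_length=2, temperature=0.1, label_smoothing=0.1,
        variance_lambda=0.001, huber_delta=0.1, normalize_indices=True,
        real_lens=torch.full((4,), 50),
    )
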
def compute_alignment_loss(embs, real_lens, steps=None, seq_lens=None, normalize_embeddings=False, loss_type='classification', similarity_type='l2', num_cycles=20, cycle_length=2, temperature=0.1, label_smoothing=0.1, variance_lambda=0.001, huber_delta=0.1, normalize_indices=True)
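Most arguments have defaults, so a typical call only needs embeddings, per-sequence lengths, and a regression loss type (the default, "classification", is not implemented in this module and would raise a ValueError). A hedged sketch with assumed shapes:

    import torch
    from dlc2action.loss.tcc import compute_alignment_loss

    embs = torch.randn(4, 100, 32)      # (batch, timesteps, embedding channels), assumed shape
    real_lens = torch.full((4,), 100)   # number of valid frames per sequence
    loss = compute_alignment_loss(
        embs,
        real_lens,
        loss_type="regression_mse_var",
        num_cycles=20,
        cycle_length=2,
    )
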
def regression_loss(logits, labels, num_steps, steps, seq_lens, loss_type, normalize_indices, variance_lambda)

Loss function based on regressing to the correct indices.

In the paper, this is called Cycle-back Regression. There are 3 variants of this loss:
i) regression_mse: MSE of the predicted indices and ground truth indices.
ii) regression_mse_var: MSE of the predicted indices that takes into account the variance of the similarities. This is important when the rate at which sequences go through different phases changes a lot. The variance scaling allows dynamic weighting of the MSE loss based on the similarities.
iii) regression_huber: Huber loss between the predicted indices and ground truth indices.

Args:
    logits: Tensor, pre-softmax similarity scores after cycling back to the starting sequence.
    labels: Tensor, one-hot labels containing the ground truth. The index where the cycle started is 1.
    num_steps: Integer, number of steps in the sequence embeddings.
    steps: Tensor, step/frame indices of the embeddings of shape [N, T], where N is the batch size and T is the number of timesteps.
    seq_lens: Tensor, lengths of the sequences from which the sampling was done. This can provide additional temporal information to the alignment loss.
    loss_type: String, specifies the kind of regression loss function. Currently supported loss functions: regression_mse, regression_mse_var, regression_huber.
    normalize_indices: Boolean, if True, normalizes indices by sequence lengths. Useful for ensuring numerical instabilities don't arise as sequence indices can be large numbers.
    variance_lambda: Float, weight of the variance of the similarity predictions while cycling back. If this is high, low-variance similarities are preferred by the loss; making this term low results in high variance of the similarities (more uniform/random matching).

Returns:
    loss: Tensor, a scalar loss calculated using a variant of regression.
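For regression_mse_var, the code above computes loss = mean(exp(-log var) * (t_true - t_pred)^2 + variance_lambda * log var), where var is the beta-weighted variance of the (optionally normalized) step indices and beta is the softmax of the logits. A toy call, with made-up shapes and values:

    import torch
    from dlc2action.loss.tcc import regression_loss

    num_steps = 5
    logits = torch.randn(3, num_steps)                      # cycle-back similarities for 3 cycles
    labels = torch.eye(num_steps)[torch.tensor([0, 2, 4])]  # one-hot starting indices
    steps = torch.arange(num_steps).unsqueeze(0).repeat(3, 1)
    seq_lens = torch.full((3,), num_steps)
    loss = regression_loss(
        logits, labels, num_steps, steps, seq_lens,
        loss_type="regression_mse_var", normalize_indices=True, variance_lambda=0.001,
    )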