dlc2action.loss.tcc

TCC loss (from https://github.com/June01/tcc_Temporal_Cycle_Consistency_Loss.pytorch).
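
For orientation, here is a minimal usage sketch (not part of the original documentation). The tensor shapes are assumptions inferred from the forward method in the source below: predictions is expected as (batch, embedding_dim, num_frames), since it is transposed to (batch, num_frames, embedding_dim) before the alignment loss, and mask marks valid frames so that mask.sum(-1) gives the real length of each sequence. Note that with a single sequence in the batch the loss is simply 0.

import torch

from dlc2action.loss.tcc import TCCLoss

# Hypothetical input: 4 sequences, 16-dimensional embeddings, 50 frames each.
predictions = torch.randn(4, 16, 50)

# Binary mask of valid frames, assumed shape (batch, num_frames);
# here the last 10 frames of every sequence are treated as padding.
mask = torch.ones(4, 50)
mask[:, 40:] = 0

loss_fn = TCCLoss(loss_type="regression_mse_var", num_cycles=20, cycle_length=2)
loss = loss_fn(predictions, mask)  # scalar tensor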

#
# Copyright 2020-present by A. Mathis Group and contributors. All rights reserved.
#
# This project and all its files are licensed under GNU AGPLv3 or later version. A copy is included in dlc2action/LICENSE.AGPL.
#
# Incorporates code adapted from Temporal Cycle-Consistency by June01
# Original work Copyright (c) 2004 June01
# Source: https://github.com/June01/tcc_Temporal_Cycle_Consistency_Loss.pytorch
# Originally licensed under Apache License Version 2.0, January 2004
# Combined work licensed under GNU AGPLv3
#
"""TCC loss (from https://github.com/June01/tcc_Temporal_Cycle_Consistency_Loss.pytorch)."""

import torch
from torch import nn
from torch.nn import functional as F


class TCCLoss(nn.Module):
    """Temporal Cycle Consistency Loss."""

    def __init__(
        self,
        loss_type: str = "regression_mse_var",
        variance_lambda: float = 0.001,
        normalize_indices: bool = True,
        normalize_embeddings: bool = False,
        similarity_type: str = "l2",
        num_cycles: int = 20,
        cycle_length: int = 2,
        temperature: float = 0.1,
        label_smoothing: float = 0.1,
    ):
        """Initialize the loss."""
        super().__init__()
        self.loss_type = loss_type
        self.variance_lambda = variance_lambda
        self.normalize_indices = normalize_indices
        self.normalize_embeddings = normalize_embeddings
        self.similarity_type = similarity_type
        self.num_cycles = num_cycles
        self.cycle_length = cycle_length
        self.temperature = temperature
        self.label_smoothing = label_smoothing

    def forward(self, predictions: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Forward pass."""
        real_lens = mask.sum(-1).squeeze().cpu()
        if len(real_lens.shape) == 0:
            return torch.tensor(0)
        return compute_alignment_loss(
            predictions.transpose(-1, -2),
            normalize_embeddings=self.normalize_embeddings,
            normalize_indices=self.normalize_indices,
            loss_type=self.loss_type,
            similarity_type=self.similarity_type,
            num_cycles=self.num_cycles,
            cycle_length=self.cycle_length,
            temperature=self.temperature,
            label_smoothing=self.label_smoothing,
            variance_lambda=self.variance_lambda,
            real_lens=real_lens,
        )


def _align_single_cycle(
    cycle, embs, cycle_length, num_steps, real_len, similarity_type, temperature
):
    # Choose a random frame index within the real (unpadded) length.
    n_idx = (torch.rand(1) * real_len).long()[0]
    # n_idx = torch.tensor(8).long()

    # Create labels
    onehot_labels = torch.eye(num_steps)[n_idx]

    # Choose query feats for first frame.
    query_feats = embs[cycle[0], n_idx : n_idx + 1]
    num_channels = query_feats.size(-1)
    for c in range(1, cycle_length + 1):
        candidate_feats = embs[cycle[c]]
        if similarity_type == "l2":
            mean_squared_distance = torch.sum(
                (query_feats.repeat([num_steps, 1]) - candidate_feats) ** 2, dim=1
            )
            similarity = -mean_squared_distance
        elif similarity_type == "cosine":
            similarity = torch.squeeze(
                torch.matmul(candidate_feats, query_feats.transpose(0, 1))
            )
        else:
            raise ValueError("similarity_type can either be l2 or cosine.")

        similarity /= float(num_channels)
        similarity /= temperature

        beta = F.softmax(similarity, dim=0).unsqueeze(1).repeat([1, num_channels])
        query_feats = torch.sum(beta * candidate_feats, dim=0, keepdim=True)

    return similarity.unsqueeze(0), onehot_labels.unsqueeze(0)


def _align(
    cycles,
    embs,
    num_steps,
    real_lens,
    num_cycles,
    cycle_length,
    similarity_type,
    temperature,
    batch_size,
):
    """Align by finding cycles in embs."""
    logits_list = []
    labels_list = []
    for i in range(num_cycles):
        if len(real_lens) == batch_size:
            real_len = int(real_lens[cycles[i][0]])
        else:
            real_len = int(real_lens[cycles[i][0] // batch_size])
        logits, labels = _align_single_cycle(
            cycles[i],
            embs,
            cycle_length,
            num_steps,
            real_len,
            similarity_type,
            temperature,
        )
        logits_list.append(logits)
        labels_list.append(labels)

    logits = torch.cat(logits_list, dim=0)
    labels = torch.cat(labels_list, dim=0).to(embs.device)

    return logits, labels


def gen_cycles(num_cycles, batch_size, cycle_length=2):
    """Generate cycles for alignment.

    Generates a batch of indices to cycle over. For example, setting num_cycles=2,
    batch_size=5, cycle_length=3 might return something like this:
    cycles = [[0, 3, 4, 0], [1, 2, 0, 1]]. This means we have 2 cycles for which
    the loss will be calculated. The first cycle starts at sequence 0 of the
    batch, then we find a matching step in sequence 3 of that batch, then we
    find a matching step in sequence 4 and finally come back to sequence 0,
    completing a cycle.

    Parameters
    ----------
    num_cycles : int
        Number of cycles that will be matched in one pass.
    batch_size : int
        Number of sequences in one batch.
    cycle_length : int
        Length of the cycles. If we are matching between 2 sequences (cycle_length=2),
        we get cycles that look like [0, 1, 0]. This means that we go from sequence 0
        to sequence 1 then back to sequence 0. A cycle length of 3 might look like
        [0, 1, 2, 0].

    Returns
    -------
    cycles : torch.Tensor
        Batch indices denoting cycles that will be used for calculating the alignment loss.

    """
    sorted_idxes = torch.arange(batch_size).unsqueeze(0).repeat([num_cycles, 1])
    sorted_idxes = sorted_idxes.view([batch_size, num_cycles])
    cycles = sorted_idxes[torch.randperm(len(sorted_idxes))].view(
        [num_cycles, batch_size]
    )
    cycles = cycles[:, :cycle_length]
    cycles = torch.cat([cycles, cycles[:, 0:1]], dim=1)

    return cycles


def compute_stochastic_alignment_loss(
    embs,
    steps,
    seq_lens,
    num_steps,
    batch_size,
    loss_type,
    similarity_type,
    num_cycles,
    cycle_length,
    temperature,
    variance_lambda,
    normalize_indices,
    real_lens,
):
    """Compute stochastic alignment loss.

    Parameters
    ----------
    embs : torch.Tensor
        Embeddings of shape (batch_size, num_steps, num_channels).
    steps : torch.Tensor
        Steps of shape (batch_size, num_steps).
    seq_lens : torch.Tensor
        Sequence lengths of shape (batch_size,).
    num_steps : int
        Number of steps in the sequence.
    batch_size : int
        Batch size.
    loss_type : str
        Type of loss to use. Only the regression variants ("regression_mse",
        "regression_mse_var") are currently supported.
    similarity_type : str
        Type of similarity to use. Can be either "l2" or "cosine".
    num_cycles : int
        Number of cycles to use for alignment.
    cycle_length : int
        Length of cycles to use for alignment.
    temperature : float
        Temperature to use for alignment.
    variance_lambda : float
        Lambda to use for variance regularization.
    normalize_indices : bool
        Whether to normalize indices.
    real_lens : torch.Tensor
        Real lengths of sequences.

    Returns
    -------
    loss : torch.Tensor
        Alignment loss.

    """
    cycles = gen_cycles(num_cycles, batch_size, cycle_length).to(embs.device)
    logits, labels = _align(
        cycles=cycles,
        embs=embs,
        num_steps=num_steps,
        real_lens=real_lens,
        num_cycles=num_cycles,
        cycle_length=cycle_length,
        similarity_type=similarity_type,
        temperature=temperature,
        batch_size=batch_size,
    )

    if "regression" in loss_type:
        steps = steps[cycles[:, 0]]
        seq_lens = seq_lens[cycles[:, 0]]
        loss = regression_loss(
            logits,
            labels,
            num_steps,
            steps,
            seq_lens,
            loss_type,
            normalize_indices,
            variance_lambda,
        )
    else:
        raise ValueError(
            "Unidentified loss type %s. Currently supported loss "
            "types are: regression_mse, regression_mse_var." % loss_type
        )
    return loss


def compute_alignment_loss(
    embs,
    real_lens,
    steps=None,
    seq_lens=None,
    normalize_embeddings=False,
    loss_type="classification",
    similarity_type="l2",
    num_cycles=20,
    cycle_length=2,
    temperature=0.1,
    label_smoothing=0.1,
    variance_lambda=0.001,
    huber_delta=0.1,
    normalize_indices=True,
):
    """Compute alignment loss.

    Parameters
    ----------
    embs : torch.Tensor
        Sequence embeddings of shape (batch_size, num_steps, emb_dim)
    real_lens : torch.Tensor
        Length of each sequence in the batch of shape (batch_size,)
    steps : torch.Tensor, optional
        Step indices of shape (batch_size, num_steps), by default None
    seq_lens : torch.Tensor, optional
        Length of each sequence in the batch of shape (batch_size,), by default None
    normalize_embeddings : bool, optional
        Whether to normalize embeddings, by default False
    loss_type : str, default "classification"
        Type of loss to use; only the regression variants ("regression_mse",
        "regression_mse_var") are currently implemented
    similarity_type : str, default "l2"
        Type of similarity to use
    num_cycles : int, default 20
        Number of cycles to use
    cycle_length : int, default 2
        Length of cycles to use
    temperature : float, default 0.1
        Temperature to use for softmax
    label_smoothing : float, default 0.1
        Label smoothing to use
    variance_lambda : float, default 0.001
        Variance lambda to use
    huber_delta : float, default 0.1
        Huber delta to use
    normalize_indices : bool, default True
        Whether to normalize indices

    """
    # Get the number of timestamps in the sequence embeddings.
    num_steps = embs.shape[1]
    batch_size = embs.shape[0]

    # If steps has not been provided, assume sampling has been done uniformly.
    if steps is None:
        steps = (
            torch.arange(0, num_steps)
            .unsqueeze(0)
            .repeat([batch_size, 1])
            .to(embs.device)
        )

    # If seq_lens has not been provided, assume it is equal to the size of the
    # time axis in the embeddings.
    if seq_lens is None:
        seq_lens = (
            torch.tensor(num_steps)
            .unsqueeze(0)
            .repeat([batch_size])
            .int()
            .to(embs.device)
        )

    # Check that the shape of steps is consistent with the embeddings.
    assert num_steps == steps.shape[1]
    assert batch_size == steps.shape[0]

    if normalize_embeddings:
        embs = F.normalize(embs, dim=-1, p=2)

    loss = compute_stochastic_alignment_loss(
        embs=embs,
        steps=steps,
        seq_lens=seq_lens,
        num_steps=num_steps,
        batch_size=batch_size,
        loss_type=loss_type,
        similarity_type=similarity_type,
        num_cycles=num_cycles,
        cycle_length=cycle_length,
        temperature=temperature,
        # label_smoothing=label_smoothing,
        variance_lambda=variance_lambda,
        # huber_delta=huber_delta,
        normalize_indices=normalize_indices,
        real_lens=real_lens,
    )

    return loss


def regression_loss(
    logits,
    labels,
    num_steps,
    steps,
    seq_lens,
    loss_type,
    normalize_indices,
    variance_lambda,
):
    """Loss function based on regressing to the correct indices.

    In the paper, this is called Cycle-back Regression. There are 3 variants
    of this loss:
    i) regression_mse: MSE of the predicted indices and ground truth indices.
    ii) regression_mse_var: MSE of the predicted indices that takes into account
    the variance of the similarities. This is important when the rate at which
    sequences go through different phases changes a lot. The variance scaling
    allows dynamic weighting of the MSE loss based on the similarities.
    iii) regression_huber: Huber loss between the predicted indices and ground
    truth indices.

    Parameters
    ----------
    logits : torch.Tensor
        Pre-softmax similarity scores after cycling back to the starting sequence
        of shape (batch_size, num_steps)
    labels : torch.Tensor
        One hot labels containing the ground truth. The index where the cycle
        started is 1. Shape (batch_size, num_steps)
    num_steps : int
        Number of steps in the sequence embeddings
    steps : torch.Tensor
        Step indices/frame indices of the embeddings of the shape (batch_size, num_steps)
    seq_lens : torch.Tensor
        Lengths of the sequences from which the sampling was done. This can
        provide additional temporal information to the alignment loss.
    loss_type : str
        This specifies the kind of regression loss function. Currently supported
        loss functions: regression_mse, regression_mse_var.
    normalize_indices : bool
        If True, normalizes indices by sequence lengths. Useful for ensuring
        numerical instabilities don't arise as sequence indices can be large
        numbers.
    variance_lambda : float
        Weight of the variance of the similarity predictions while cycling back
        to the starting sequence.

    Returns
    -------
    loss : torch.Tensor
        A scalar loss calculated using a variant of regression.

    """
    # Just to be safe, we stop gradients from labels as we are generating labels.
    labels = labels.detach()
    steps = steps.detach()

    if normalize_indices:
        float_seq_lens = seq_lens.float()
        tile_seq_lens = (
            torch.tile(torch.unsqueeze(float_seq_lens, dim=1), [1, num_steps]) + 1e-7
        )
        steps = steps.float() / tile_seq_lens
    else:
        steps = steps.float()

    beta = F.softmax(logits, dim=1)
    true_time = torch.sum(steps * labels, dim=1)
    pred_time = torch.sum(steps * beta, dim=1)

    if loss_type in ["regression_mse", "regression_mse_var"]:
        if "var" in loss_type:
            # Variance aware regression.
            pred_time_tiled = torch.tile(
                torch.unsqueeze(pred_time, dim=1), [1, num_steps]
            )

            pred_time_variance = torch.sum(
                ((steps - pred_time_tiled) ** 2) * beta, dim=1
            )

            # Using log of variance as it is numerically stabler.
            pred_time_log_var = torch.log(pred_time_variance + 1e-7)
            squared_error = (true_time - pred_time) ** 2
            return torch.mean(
                torch.exp(-pred_time_log_var) * squared_error
                + variance_lambda * pred_time_log_var
            )

        else:
            return torch.mean((true_time - pred_time) ** 2)
    else:
        raise ValueError(
            "Unsupported regression loss %s. Supported losses are: "
            "regression_mse, regression_mse_var." % loss_type
        )
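
For reference, the variance-aware branch above corresponds to minimizing, per sampled cycle and averaged over cycles,

    L = \frac{(t - \mu)^2}{\sigma^2} + \lambda \log \sigma^2

where \beta = softmax(logits) over the time steps of the starting sequence, t is the (optionally length-normalized) ground-truth index, \mu = \sum_k \beta_k t_k is the predicted index, \sigma^2 = \sum_k \beta_k (t_k - \mu)^2 is its variance under \beta, and \lambda is variance_lambda. The code adds a small constant (1e-7) to the variance inside the logarithm, and hence in the division, for numerical stability.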
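
The functional entry point can also be called directly on a batch of embeddings. The sketch below is illustrative only: the shapes and lengths are made up, and loss_type is set explicitly to a regression variant because the "classification" default is not implemented in this module.

import torch

from dlc2action.loss.tcc import compute_alignment_loss

# Hypothetical batch: 3 sequences, 40 time steps, 32-dimensional embeddings.
embs = torch.randn(3, 40, 32)
real_lens = torch.tensor([40, 35, 28])  # number of valid steps per sequence

loss = compute_alignment_loss(
    embs,
    real_lens=real_lens,
    loss_type="regression_mse_var",  # "classification" would raise a ValueError
    similarity_type="l2",
    num_cycles=10,
    cycle_length=2,
    temperature=0.1,
    variance_lambda=0.001,
)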