dlc2action.loss.ms_tcn

Loss for the MS-TCN models

Adapted from https://github.com/sj-li/MS-TCN2

#
# Copyright 2020-2022 by A. Mathis Group and contributors. All rights reserved.
#
# This project and all its files are licensed under GNU AGPLv3 or later version. A copy is included in dlc2action/LICENSE.AGPL.
#
#
# Adapted from MS-TCN++ by yabufarha
# Adapted from https://github.com/sj-li/MS-TCN2
# Licensed under MIT License
#
"""
Loss for the MS-TCN models

Adapted from https://github.com/sj-li/MS-TCN2
"""

from torch import nn
import torch.nn.functional as F
import torch
from collections.abc import Iterable
from copy import copy
from typing import Optional


class MS_TCN_Loss(nn.Module):
    """
    The MS-TCN loss
    Cross-entropy + consistency loss (truncated MSE over frame-wise log-probabilities)
    """

    def __init__(
        self,
        num_classes: int,
        weights: Optional[Iterable] = None,
        exclusive: bool = True,
        ignore_index: int = -100,
        focal: bool = False,
        gamma: float = 1,
        alpha: float = 0.15,
        hard_negative_weight: float = 1,
    ) -> None:
        """
        Parameters
        ----------
        num_classes : int
            number of classes
        weights : iterable, optional
            class-wise cross-entropy weights
        exclusive : bool, default True
            True if single-label classification is used
        ignore_index : int, default -100
            the elements where the target is equal to ignore_index will be ignored by cross-entropy
        focal : bool, default False
            if True, the focal loss is used instead of regular cross-entropy
        gamma : float, default 1
            the gamma parameter of the focal loss
        alpha : float, default 0.15
            the weight of the consistency loss
        hard_negative_weight : float, default 1
            the weight assigned to the hard negative frames
        """

        super().__init__()
        self.weights = weights
        self.num_classes = int(num_classes)
        self.ignore_index = ignore_index
        self.gamma = gamma
        self.focal = focal
        self.alpha = alpha
        self.exclusive = exclusive
        self.neg_weight = hard_negative_weight
        if exclusive:
            self.log_nl = lambda x: F.log_softmax(x, dim=1)
        else:
            self.log_nl = lambda x: torch.log(torch.sigmoid(x) + 1e-7)
        self.mse = nn.MSELoss(reduction="none")
        if self.weights is not None:
            # the weights can only be moved to the right device at the
            # first forward pass, so initialization is deferred
            self.need_init = True
        else:
            self.need_init = False
            self._init_ce()

    def _init_ce(self) -> None:
        """
        Initialize the cross-entropy function
        """

        if self.exclusive:
            if self.focal:
                self.ce = nn.CrossEntropyLoss(
                    ignore_index=self.ignore_index,
                    weight=self.weights,
                    reduction="none",
                )
            else:
                self.ce = nn.CrossEntropyLoss(
                    ignore_index=self.ignore_index, weight=self.weights
                )
        else:
            self.ce = nn.BCEWithLogitsLoss(reduction="none")

    def _init_weights(self, device: str) -> None:
        """
        Initialize the weights vector and the cross-entropy function (after the device is known)
        """

        if self.exclusive:
            self.weights = torch.tensor(self.weights, device=device, dtype=torch.float)
        else:
            self.weights = {
                k: torch.tensor(v, device=device, dtype=torch.float)
                .unsqueeze(0)
                .unsqueeze(-1)
                for k, v in self.weights.items()
            }
        self._init_ce()
        self.need_init = False

    def _ce_loss(self, p: torch.Tensor, t: torch.Tensor) -> float:
        """
        Apply the cross-entropy loss
        """

        if self.exclusive:
            p = p.transpose(2, 1).contiguous().view(-1, self.num_classes)
            t = t.view(-1)
            mask = t != self.ignore_index
            if torch.sum(mask) == 0:
                return 0
            if self.focal:
                pr = F.softmax(p[mask], dim=1)
                f = (1 - pr[range(torch.sum(mask)), t[mask]]) ** self.gamma
                loss = (f * self.ce(p, t)[mask]).mean()
                return loss
            else:
                loss = self.ce(p, t)
                return loss
        else:
            if self.weights is not None:
                weight0 = self.weights[0]
                weight1 = self.weights[1]
            else:
                weight0 = 1
                weight1 = 1
            # frames labeled 2 are hard negatives: they are trained as
            # negatives (0) but can be re-weighted with self.neg_weight
            neg_mask = t == 2
            target = copy(t)
            target[neg_mask] = 0
            loss = self.ce(p, target)
            loss[t == 1] = (loss * weight1)[t == 1]
            loss[t == 0] = (loss * weight0)[t == 0]
            if self.neg_weight > 1:
                loss[neg_mask] = (loss * self.neg_weight)[neg_mask]
            elif self.neg_weight < 1:
                inv_neg_mask = (~neg_mask) * (target == 0)
                loss[inv_neg_mask] = (loss * self.neg_weight)[inv_neg_mask]
            if self.focal:
                pr = torch.sigmoid(p)
                factor = target * ((1 - pr) ** self.gamma) + (1 - target) * (
                    pr**self.gamma
                )
                loss = loss * factor
            loss = loss[target != self.ignore_index]
            return loss.mean() if loss.size()[-1] != 0 else 0

    def consistency_loss(self, p: torch.Tensor) -> float:
        """
        Apply the consistency (smoothing) loss: a truncated MSE between the
        log-probabilities of neighboring frames
        """

        mse = self.mse(self.log_nl(p[:, :, 1:]), self.log_nl(p.detach()[:, :, :-1]))
        clamp = torch.clamp(mse, min=0, max=16)
        return torch.mean(clamp)

    def forward(self, predictions: torch.Tensor, target: torch.Tensor) -> float:
        """
        Compute the loss

        Parameters
        ----------
        predictions : torch.Tensor
            a tensor of shape (#batch, #classes, #frames) or, for multi-stage
            output, (#stages, #batch, #classes, #frames)
        target : torch.Tensor
            a tensor of shape (#batch, #classes, #frames) or (#batch, #frames)

        Returns
        -------
        loss : float
            the loss value
        """

        if self.need_init:
            if isinstance(predictions, dict):
                device = predictions["device"]
            else:
                device = predictions.device
            self._init_weights(device)
        loss = 0
        if len(predictions.shape) == 4:
            # average the loss over the stages of a multi-stage prediction
            for p in predictions:
                loss += self._ce_loss(p, target)
                loss += self.alpha * self.consistency_loss(p)
        else:
            loss += self._ce_loss(predictions, target)
            loss += self.alpha * self.consistency_loss(predictions)
        return loss / len(predictions)
class MS_TCN_Loss(torch.nn.modules.module.Module):

The MS-TCN loss: cross-entropy plus a consistency loss (truncated MSE over the predicted frame-wise log-probabilities).

MS_TCN_Loss(num_classes: int, weights: Optional[collections.abc.Iterable] = None, exclusive: bool = True, ignore_index: int = -100, focal: bool = False, gamma: float = 1, alpha: float = 0.15, hard_negative_weight: float = 1)

Parameters

num_classes : int
    number of classes
weights : iterable, optional
    class-wise cross-entropy weights
exclusive : bool, default True
    True if single-label classification is used
ignore_index : int, default -100
    the elements where the target is equal to ignore_index will be ignored by cross-entropy
focal : bool, default False
    if True, the focal loss is used instead of regular cross-entropy
gamma : float, default 1
    the gamma parameter of the focal loss
alpha : float, default 0.15
    the weight of the consistency loss
hard_negative_weight : float, default 1
    the weight assigned to the hard negative frames
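
In the non-exclusive (multi-label) case, the code expects weights to be a dictionary mapping the label values 0 and 1 to per-class weight vectors, and interprets target frames equal to 2 as hard negatives (trained as negatives but re-weighted by hard_negative_weight). A hedged sketch with arbitrary values:

import torch
from dlc2action.loss.ms_tcn import MS_TCN_Loss

num_classes, batch, num_frames = 3, 2, 50
loss_fn = MS_TCN_Loss(
    num_classes=num_classes,
    exclusive=False,
    weights={0: [1.0, 1.0, 1.0], 1: [2.0, 1.0, 3.0]},  # per-class weights
    focal=True,
    gamma=2.0,
    hard_negative_weight=2.0,
)

predictions = torch.randn(batch, num_classes, num_frames, requires_grad=True)
# float targets: 0 = negative, 1 = positive, 2 = hard negative
target = torch.randint(0, 3, (batch, num_classes, num_frames)).float()

loss = loss_fn(predictions, target)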

def consistency_loss(self, p: torch.Tensor) -> float:

Apply the consistency (smoothing) loss: a truncated MSE between the log-probabilities of neighboring frames, which penalizes over-fragmented predictions.
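
Roughly, this computes mean(clamp((log p[:, :, 1:] - log p.detach()[:, :, :-1]) ** 2, max=16)). An illustrative sanity check (not from the original docs):

import torch
from dlc2action.loss.ms_tcn import MS_TCN_Loss

loss_fn = MS_TCN_Loss(num_classes=4, exclusive=True)

smooth = torch.zeros(1, 4, 20)       # constant logits -> zero penalty
noisy = torch.randn(1, 4, 20) * 10   # frame-to-frame jumps -> positive penalty
assert loss_fn.consistency_loss(smooth) < loss_fn.consistency_loss(noisy)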

def forward(self, predictions: torch.Tensor, target: torch.Tensor) -> float:

Compute the loss

Parameters

predictions : torch.Tensor
    a tensor of shape (#batch, #classes, #frames) or, for multi-stage output, (#stages, #batch, #classes, #frames)
target : torch.Tensor
    a tensor of shape (#batch, #classes, #frames) or (#batch, #frames)

Returns

loss : float
    the loss value
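
An illustrative check (not from the original docs) that, for stacked multi-stage input, forward returns the mean over stages of the per-stage cross-entropy plus alpha times the consistency loss:

import torch
from dlc2action.loss.ms_tcn import MS_TCN_Loss

loss_fn = MS_TCN_Loss(num_classes=3, exclusive=True, alpha=0.15)
predictions = torch.randn(4, 2, 3, 25)  # (#stages, #batch, #classes, #frames)
target = torch.randint(0, 3, (2, 25))

total = loss_fn(predictions, target)
per_stage = [
    loss_fn._ce_loss(p, target) + loss_fn.alpha * loss_fn.consistency_loss(p)
    for p in predictions
]
assert torch.isclose(total, sum(per_stage) / len(per_stage))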
