dlc2action.loss.ms_tcn

Loss for the MS-TCN models.

#
# Copyright 2020-present by A. Mathis Group and contributors. All rights reserved.
#
# This project and all its files are licensed under GNU AGPLv3 or later version. A copy is included in dlc2action/LICENSE.AGPL.
#
# Incorporates code adapted from MS-TCN++ by yabufarha
# Original work Copyright (c) 2019 June01
# Source: https://github.com/sj-li/MS-TCN2
# Originally licensed under MIT License
# Combined work licensed under GNU AGPLv3
#
"""Loss for the MS-TCN models."""

from collections.abc import Iterable
from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn


class MS_TCN_Loss(nn.Module):
    """The MS-TCN loss.

    Cross-entropy + consistency loss (MSE over predicted log-probabilities).
    """

    def __init__(
        self,
        num_classes: int,
        weights: Optional[Iterable] = None,
        exclusive: bool = True,
        ignore_index: int = -100,
        focal: bool = False,
        gamma: float = 1,
        alpha: float = 0.15,
        hard_negative_weight: float = 1,
    ) -> None:
        """Initialize the loss.

        Parameters
        ----------
        num_classes : int
            number of classes
        weights : iterable, optional
            class-wise cross-entropy weights
        exclusive : bool, default True
            True if single-label classification is used
        ignore_index : int, default -100
            target elements equal to ignore_index are ignored by the cross-entropy
        focal : bool, default False
            if True, the focal loss is used instead of regular cross-entropy
        gamma : float, default 1
            the gamma parameter of the focal loss
        alpha : float, default 0.15
            the weight of the consistency loss
        hard_negative_weight : float, default 1
            the weight assigned to the hard negative frames

        """
        super().__init__()
        self.weights = weights
        self.num_classes = int(num_classes)
        self.ignore_index = ignore_index
        self.gamma = gamma
        self.focal = focal
        self.alpha = alpha
        self.exclusive = exclusive
        self.neg_weight = hard_negative_weight
        # log-probabilities: log-softmax for single-label, log-sigmoid otherwise
        if exclusive:
            self.log_nl = lambda x: F.log_softmax(x, dim=1)
        else:
            self.log_nl = lambda x: torch.log(torch.sigmoid(x) + 1e-7)
        self.mse = nn.MSELoss(reduction="none")
        # weight tensors can only be built once the device is known (in forward)
        if self.weights is not None:
            self.need_init = True
        else:
            self.need_init = False
            self._init_ce()

    def _init_ce(self) -> None:
        """Initialize cross-entropy function."""
        if self.exclusive:
            if self.focal:
                self.ce = nn.CrossEntropyLoss(
                    ignore_index=self.ignore_index,
                    weight=self.weights,
                    reduction="none",
                )
            else:
                self.ce = nn.CrossEntropyLoss(
                    ignore_index=self.ignore_index, weight=self.weights
                )
        else:
            self.ce = nn.BCEWithLogitsLoss(reduction="none")

    def _init_weights(self, device: torch.device) -> None:
        """Initialize the weights vector and the cross-entropy function (after the device is known)."""
        if self.exclusive:
            self.weights = torch.tensor(self.weights, device=device, dtype=torch.float)
        else:
            # reshape each weight vector to (1, #classes, 1) so it broadcasts
            # over the (#batch, #classes, #frames) loss tensor
            self.weights = {
                k: torch.tensor(v, device=device, dtype=torch.float)
                .unsqueeze(0)
                .unsqueeze(-1)
                for k, v in self.weights.items()
            }
        self._init_ce()
        self.need_init = False

    def _ce_loss(self, p: torch.Tensor, t: torch.Tensor) -> float:
        """Apply cross-entropy loss."""
        if self.exclusive:
            t = t.long()
            # flatten to (#batch * #frames, #classes) for cross-entropy
            p = p.transpose(2, 1).contiguous().view(-1, self.num_classes)
            t = t.view(-1)
            mask = t != self.ignore_index
            if torch.sum(mask) == 0:
                return 0
            if self.focal:
                pr = F.softmax(p[mask], dim=1)
                # focal factor: down-weight confidently classified frames
                f = (1 - pr[range(torch.sum(mask)), t[mask]]) ** self.gamma
                loss = (f * self.ce(p, t)[mask]).mean()
                return loss
            else:
                loss = self.ce(p, t)
                return loss
        else:
            if self.weights is not None:
                weight0 = self.weights[0]
                weight1 = self.weights[1]
            else:
                weight0 = 1
                weight1 = 1
            # the label 2 marks hard negative frames; treat them as negatives
            neg_mask = t == 2
            target = t.clone()
            target[neg_mask] = 0
            loss = self.ce(p, target)
            loss[t == 1] = (loss * weight1)[t == 1]
            loss[t == 0] = (loss * weight0)[t == 0]
            if self.neg_weight > 1:
                loss[neg_mask] = (loss * self.neg_weight)[neg_mask]
            elif self.neg_weight < 1:
                inv_neg_mask = (~neg_mask) & (target == 0)
                loss[inv_neg_mask] = (loss * self.neg_weight)[inv_neg_mask]
            if self.focal:
                pr = torch.sigmoid(p)
                factor = target * ((1 - pr) ** self.gamma) + (1 - target) * (
                    pr**self.gamma
                )
                loss = loss * factor
            loss = loss[target != self.ignore_index]
            return loss.mean() if loss.size()[-1] != 0 else 0

    def consistency_loss(self, p: torch.Tensor) -> float:
        """Apply consistency loss."""
        # truncated MSE between log-probabilities of consecutive frames;
        # the previous frame is detached so gradients only flow through frame t
        mse = self.mse(self.log_nl(p[:, :, 1:]), self.log_nl(p.detach()[:, :, :-1]))
        clamp = torch.clamp(mse, min=0, max=16)
        return torch.mean(clamp)

    def forward(self, predictions: torch.Tensor, target: torch.Tensor) -> float:
        """Compute the loss.

        Parameters
        ----------
        predictions : torch.Tensor
            a tensor of shape (#batch, #classes, #frames) or
            (#stages, #batch, #classes, #frames) for multi-stage models
        target : torch.Tensor
            a tensor of shape (#batch, #classes, #frames) or (#batch, #frames)

        Returns
        -------
        loss : float
            the loss value

        """
        if self.need_init:
            if isinstance(predictions, dict):
                device = predictions["device"]
            else:
                device = predictions.device
            self._init_weights(device)
        loss = 0
        if len(predictions.shape) == 4:
            # multi-stage output: sum the losses over the refinement stages
            for p in predictions:
                loss += self._ce_loss(p, target)
                loss += self.alpha * self.consistency_loss(p)
        else:
            loss += self._ce_loss(predictions, target)
            loss += self.alpha * self.consistency_loss(predictions)
        return loss / len(predictions)
class MS_TCN_Loss(torch.nn.modules.module.Module):

The MS-TCN loss.

Cross-entropy + consistency loss (MSE over predicted log-probabilities).
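For multi-stage (4-D) predictions, forward() sums a cross-entropy term and an alpha-weighted consistency term per stage and averages over the S stages:

\mathcal{L} = \frac{1}{S} \sum_{s=1}^{S} \left( \mathcal{L}_{\mathrm{CE}}^{(s)} + \alpha \, \mathcal{L}_{\mathrm{T\text{-}MSE}}^{(s)} \right)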

MS_TCN_Loss(num_classes: int, weights: Optional[Iterable] = None, exclusive: bool = True, ignore_index: int = -100, focal: bool = False, gamma: float = 1, alpha: float = 0.15, hard_negative_weight: float = 1)

Initialize the loss.

Parameters
----------
num_classes : int
    number of classes
weights : iterable, optional
    class-wise cross-entropy weights
exclusive : bool, default True
    True if single-label classification is used
ignore_index : int, default -100
    target elements equal to ignore_index are ignored by the cross-entropy
focal : bool, default False
    if True, the focal loss is used instead of regular cross-entropy
gamma : float, default 1
    the gamma parameter of the focal loss
alpha : float, default 0.15
    the weight of the consistency loss
hard_negative_weight : float, default 1
    the weight assigned to the hard negative frames
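When focal=True, _ce_loss() scales the cross-entropy in the single-label branch by the standard focal factor, with \hat{p} the predicted probability of the true class (the multi-label branch applies the analogous per-class factor):

\mathcal{L}_{\mathrm{focal}} = (1 - \hat{p})^{\gamma} \, \mathcal{L}_{\mathrm{CE}}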

weights
num_classes
ignore_index
gamma
focal
alpha
exclusive
neg_weight
mse
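The weights argument accepts two formats, read off _init_weights() and _ce_loss() above: a per-class sequence in the exclusive case (passed to CrossEntropyLoss), and a {0: ..., 1: ...} dict of per-class weights for negative and positive frames in the non-exclusive case. A sketch with made-up values:

# single-label: one weight per class
exclusive_loss = MS_TCN_Loss(num_classes=3, weights=[1.0, 2.0, 0.5])

# multi-label: per-class weights for negative (0) and positive (1) frames
multilabel_loss = MS_TCN_Loss(
    num_classes=3,
    exclusive=False,
    weights={0: [1.0, 1.0, 1.0], 1: [2.0, 4.0, 1.5]},
)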
def consistency_loss(self, p: torch.Tensor) -> float:

Apply consistency loss.
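Concretely, reading off the code above: with frame-wise log-probabilities \log\hat{y} (log-softmax if exclusive, log-sigmoid otherwise) and the previous frame detached from the graph,

\mathcal{L}_{\mathrm{T\text{-}MSE}} = \frac{1}{N} \sum_{b,c,\,t>1} \min\!\left( \left( \log\hat{y}_{b,c,t} - \log\hat{y}_{b,c,t-1} \right)^{2}, \, 16 \right)

where N is the number of summed terms; the truncation at 16 keeps genuine action boundaries from being over-penalized.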

def forward(self, predictions: torch.Tensor, target: torch.Tensor) -> float:

Compute the loss.

Parameters
----------
predictions : torch.Tensor
    a tensor of shape (#batch, #classes, #frames) or
    (#stages, #batch, #classes, #frames) for multi-stage models
target : torch.Tensor
    a tensor of shape (#batch, #classes, #frames) or (#batch, #frames)

Returns
-------
loss : float
    the loss value
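A minimal sketch of calling forward() in the default single-label mode, with shapes assumed from the docstring above and made-up values:

import torch

criterion = MS_TCN_Loss(num_classes=5)

predictions = torch.randn(4, 2, 5, 100)  # (#stages, #batch, #classes, #frames)
target = torch.randint(0, 5, (2, 100))   # (#batch, #frames) class indices
target[:, :10] = -100                    # ignored by the cross-entropy
loss = criterion(predictions, target)
loss.backward()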