dlc2action.loss.ms_tcn
Loss for the MS-TCN models
Adapted from https://github.com/sj-li/MS-TCN2
```python
#
# Copyright 2020-2022 by A. Mathis Group and contributors. All rights reserved.
#
# This project and all its files are licensed under GNU AGPLv3 or later version. A copy is included in dlc2action/LICENSE.AGPL.
#
#
# Adapted from MS-TCN++ by yabufarha
# Adapted from https://github.com/sj-li/MS-TCN2
# Licensed under MIT License
#
"""
Loss for the MS-TCN models

Adapted from https://github.com/sj-li/MS-TCN2
"""

from torch import nn
import torch.nn.functional as F
import torch
from collections.abc import Iterable
from copy import copy
import sys
from typing import Optional


class MS_TCN_Loss(nn.Module):
    """
    The MS-TCN loss

    Cross-entropy + consistency loss (MSE over predicted probabilities)
    """

    def __init__(
        self,
        num_classes: int,
        weights: Iterable = None,
        exclusive: bool = True,
        ignore_index: int = -100,
        focal: bool = False,
        gamma: float = 1,
        alpha: float = 0.15,
        hard_negative_weight: float = 1,
    ) -> None:
        """
        Parameters
        ----------
        num_classes : int
            number of classes
        weights : iterable, optional
            class-wise cross-entropy weights
        exclusive : bool, default True
            True if single-label classification is used
        ignore_index : int, default -100
            the elements where target is equal to ignore_index will be ignored by cross-entropy
        focal : bool, default False
            if True, the focal loss is used instead of regular cross-entropy
        gamma : float, default 1
            the gamma parameter of the focal loss
        alpha : float, default 0.15
            the weight of the consistency loss
        hard_negative_weight : float, default 1
            the weight assigned to the hard negative frames
        """

        super().__init__()
        self.weights = weights
        self.num_classes = int(num_classes)
        self.ignore_index = ignore_index
        self.gamma = gamma
        self.focal = focal
        self.alpha = alpha
        self.exclusive = exclusive
        self.neg_weight = hard_negative_weight
        if exclusive:
            self.log_nl = lambda x: F.log_softmax(x, dim=1)
        else:
            self.log_nl = lambda x: torch.log(torch.sigmoid(x) + 1e-7)
        self.mse = nn.MSELoss(reduction="none")
        if self.weights is not None:
            self.need_init = True
        else:
            self.need_init = False
            self._init_ce()

    def _init_ce(self) -> None:
        """
        Initialize cross-entropy function
        """

        if self.exclusive:
            if self.focal:
                self.ce = nn.CrossEntropyLoss(
                    ignore_index=self.ignore_index,
                    weight=self.weights,
                    reduction="none",
                )
            else:
                self.ce = nn.CrossEntropyLoss(
                    ignore_index=self.ignore_index, weight=self.weights
                )
        else:
            self.ce = nn.BCEWithLogitsLoss(reduction="none")

    def _init_weights(self, device: str) -> None:
        """
        Initialize the weights vector and the cross-entropy function (after the device is known)
        """

        if self.exclusive:
            self.weights = torch.tensor(self.weights, device=device, dtype=torch.float)
        else:
            self.weights = {
                k: torch.tensor(v, device=device, dtype=torch.float)
                .unsqueeze(0)
                .unsqueeze(-1)
                for k, v in self.weights.items()
            }
        self._init_ce()
        self.need_init = False

    def _ce_loss(self, p: torch.Tensor, t: torch.Tensor) -> float:
        """
        Apply cross-entropy loss
        """

        if self.exclusive:
            p = p.transpose(2, 1).contiguous().view(-1, self.num_classes)
            t = t.view(-1)
            mask = t != self.ignore_index
            if torch.sum(mask) == 0:
                return 0
            if self.focal:
                pr = F.softmax(p[mask], dim=1)
                f = (1 - pr[range(torch.sum(mask)), t[mask]]) ** self.gamma
                loss = (f * self.ce(p, t)[mask]).mean()
                return loss
            else:
                loss = self.ce(p, t)
                return loss
        else:
            if self.weights is not None:
                weight0 = self.weights[0]
                weight1 = self.weights[1]
            else:
                weight0 = 1
                weight1 = 1
            neg_mask = t == 2
            target = copy(t)
            target[neg_mask] = 0
            loss = self.ce(p, target)
            loss[t == 1] = (loss * weight1)[t == 1]
            loss[t == 0] = (loss * weight0)[t == 0]
            if self.neg_weight > 1:
                loss[neg_mask] = (loss * self.neg_weight)[neg_mask]
            elif self.neg_weight < 1:
                inv_neg_mask = (~neg_mask) * (target == 0)
                loss[inv_neg_mask] = (loss * self.neg_weight)[inv_neg_mask]
            if self.focal:
                pr = torch.sigmoid(p)
                factor = target * ((1 - pr) ** self.gamma) + (1 - target) * (
                    pr**self.gamma
                )
                loss = loss * factor
            loss = loss[target != self.ignore_index]
            return loss.mean() if loss.size()[-1] != 0 else 0

    def consistency_loss(self, p: torch.Tensor) -> float:
        """
        Apply consistency loss
        """

        mse = self.mse(self.log_nl(p[:, :, 1:]), self.log_nl(p.detach()[:, :, :-1]))
        clamp = torch.clamp(mse, min=0, max=16)
        return torch.mean(clamp)

    def forward(self, predictions: torch.Tensor, target: torch.Tensor) -> float:
        """
        Compute the loss

        Parameters
        ----------
        predictions : torch.Tensor
            a tensor of shape (#batch, #classes, #frames) or
            (#stages, #batch, #classes, #frames)
        target : torch.Tensor
            a tensor of shape (#batch, #classes, #frames) or (#batch, #frames)

        Returns
        -------
        loss : float
            the loss value
        """

        if self.need_init:
            if isinstance(predictions, dict):
                device = predictions["device"]
            else:
                device = predictions.device
            self._init_weights(device)
        loss = 0
        if len(predictions.shape) == 4:
            for p in predictions:
                loss += self._ce_loss(p, target)
                loss += self.alpha * self.consistency_loss(p)
        else:
            loss += self._ce_loss(predictions, target)
            loss += self.alpha * self.consistency_loss(predictions)
        return loss / len(predictions)
```
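A minimal usage sketch for the single-label (exclusive) case; the shapes and the four-stage prediction stack are illustrative of a typical MS-TCN output, not prescribed by the API:

```python
import torch
from dlc2action.loss.ms_tcn import MS_TCN_Loss

loss_fn = MS_TCN_Loss(num_classes=5, exclusive=True, alpha=0.15)

# one prediction per refinement stage: (#stages, #batch, #classes, #frames)
predictions = torch.randn(4, 2, 5, 100)
# single-label targets: one class index per frame
target = torch.randint(0, 5, (2, 100))

loss = loss_fn(predictions, target)  # summed per stage, averaged over stages
loss.backward()
```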
The MS-TCN loss

Cross-entropy + consistency loss (MSE over predicted probabilities)
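When a stack of stage predictions is passed, `forward` sums the classification and smoothing terms per refinement stage and averages over the $S$ stages:

$$\mathcal{L} \;=\; \frac{1}{S} \sum_{s=1}^{S} \left( \mathcal{L}_{\mathrm{CE}}^{(s)} + \alpha \, \mathcal{L}_{\mathrm{T\text{-}MSE}}^{(s)} \right)$$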
Parameters
num_classes : int
    number of classes
weights : iterable, optional
    class-wise cross-entropy weights (in non-exclusive mode, a dictionary with keys 0 and 1 mapping to per-class weight lists)
exclusive : bool, default True
    True if single-label classification is used
ignore_index : int, default -100
    the elements where target is equal to ignore_index will be ignored by cross-entropy
focal : bool, default False
    if True, the focal loss is used instead of regular cross-entropy
gamma : float, default 1
    the gamma parameter of the focal loss (see the formula after this list)
alpha : float, default 0.15
    the weight of the consistency loss
hard_negative_weight : float, default 1
    the weight assigned to hard negative frames
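The focal option applies the standard focal modulation of Lin et al. (2017): with $p_t$ the predicted probability of the true class, each frame's cross-entropy term is scaled as

$$\mathrm{FL}(p_t) = (1 - p_t)^{\gamma} \, \mathrm{CE}(p_t)$$

so larger `gamma` down-weights frames the model already classifies confidently; this matches `f = (1 - pr[...]) ** self.gamma` in `_ce_loss`.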
Apply the consistency loss: MSE between the log-probabilities of consecutive frames, clamped at 16, with the previous frame detached so gradients flow only through the current one
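For reference, a standalone sketch of the same computation for the exclusive case (the function name is illustrative; in the class, the log-probability map depends on `exclusive`):

```python
import torch
import torch.nn.functional as F

def truncated_mse(p: torch.Tensor, clamp_max: float = 16.0) -> torch.Tensor:
    # current-frame log-probabilities, frames 1..T-1
    cur = F.log_softmax(p[:, :, 1:], dim=1)
    # previous-frame log-probabilities, detached so gradients
    # flow only through the current frame
    prev = F.log_softmax(p.detach()[:, :, :-1], dim=1)
    # element-wise squared difference, truncated at clamp_max, then averaged
    return torch.clamp((cur - prev) ** 2, max=clamp_max).mean()
```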
Compute the loss
Parameters
predictions : torch.Tensor
    a tensor of shape (#batch, #classes, #frames) or (#stages, #batch, #classes, #frames), with one prediction per refinement stage in the four-dimensional case
target : torch.Tensor
    a tensor of shape (#batch, #classes, #frames) or (#batch, #frames)
Returns
loss : float
    the loss value
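In the non-exclusive branch of `_ce_loss`, a target value of 2 marks a hard negative frame: it is scored against a zero label by BCE and then reweighted by `hard_negative_weight`. A minimal multi-label sketch (shapes and values are illustrative, not prescribed by the API):

```python
import torch
from dlc2action.loss.ms_tcn import MS_TCN_Loss

loss_fn = MS_TCN_Loss(num_classes=3, exclusive=False, hard_negative_weight=2.0)

predictions = torch.randn(4, 2, 3, 100)  # (#stages, #batch, #classes, #frames)
# per-class float targets: 0 = negative, 1 = positive, 2 = hard negative
target = torch.randint(0, 3, (2, 3, 100)).float()

loss = loss_fn(predictions, target)  # hard negatives weighted 2x
```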
Inherited Members
- torch.nn.modules.module.Module
- dump_patches
- register_buffer
- register_parameter
- add_module
- register_module
- get_submodule
- get_parameter
- get_buffer
- get_extra_state
- set_extra_state
- apply
- cuda
- ipu
- xpu
- cpu
- type
- float
- double
- half
- bfloat16
- to_empty
- to
- register_backward_hook
- register_full_backward_hook
- register_forward_pre_hook
- register_forward_hook
- T_destination
- state_dict
- register_load_state_dict_post_hook
- load_state_dict
- parameters
- named_parameters
- buffers
- named_buffers
- children
- named_children
- modules
- named_modules
- train
- eval
- requires_grad_
- zero_grad
- extra_repr