dlc2action.loss.ms_tcn
Loss for the MS-TCN models.
#
# Copyright 2020-present by A. Mathis Group and contributors. All rights reserved.
#
# This project and all its files are licensed under GNU AGPLv3 or later version. A copy is included in dlc2action/LICENSE.AGPL.
#
# Incorporates code adapted from MS-TCN++ by yabufarha
# Original work Copyright (c) 2019 June01
# Source: https://github.com/sj-li/MS-TCN2
# Originally licensed under MIT License
# Combined work licensed under GNU AGPLv3
#
"""Loss for the MS-TCN models."""

import sys
from collections.abc import Iterable
from copy import copy
from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn


class MS_TCN_Loss(nn.Module):
    """The MS-TCN loss.

    Cross-entropy (binary cross-entropy when `exclusive` is False) plus a
    consistency loss: the MSE between the log-probabilities of adjacent
    frames, clamped at 16 and weighted by `alpha`.
    """

    def __init__(
        self,
        num_classes: int,
        weights: Iterable = None,
        exclusive: bool = True,
        ignore_index: int = -100,
        focal: bool = False,
        gamma: float = 1,
        alpha: float = 0.15,
        hard_negative_weight: float = 1,
    ) -> None:
        """Initialize the loss.

        Parameters
        ----------
        num_classes : int
            number of classes
        weights : iterable, optional
            class-wise cross-entropy weights; when `exclusive` is False a
            dictionary with keys 0 and 1 (negative / positive weights) is expected
        exclusive : bool, default True
            True if single-label classification is used
        ignore_index : int, default -100
            the elements where target is equal to ignore_index will be ignored by cross-entropy
        focal : bool, default False
            if True, instead of regular cross-entropy the focal loss will be used
        gamma : float, default 1
            the gamma parameter of the focal loss
        alpha : float, default 0.15
            the weight of the consistency loss
        hard_negative_weight : float, default 1
            the weight assigned to the hard negative frames

        """
        super().__init__()
        self.num_classes = int(num_classes)
        self.weights = weights
        self.exclusive = exclusive
        self.ignore_index = ignore_index
        self.focal = focal
        self.gamma = gamma
        self.alpha = alpha
        self.neg_weight = hard_negative_weight
        if exclusive:
            self.log_nl = lambda x: F.log_softmax(x, dim=1)
        else:
            self.log_nl = lambda x: torch.log(torch.sigmoid(x) + 1e-7)
        self.mse = nn.MSELoss(reduction="none")
        # With weights, building the criterion is postponed to the first
        # forward call, when the device of the predictions is known.
        self.need_init = self.weights is not None
        if not self.need_init:
            self._init_ce()

    def _init_ce(self) -> None:
        """Initialize cross-entropy function."""
        if not self.exclusive:
            # Class weights are applied manually in `_ce_loss` for this case.
            self.ce = nn.BCEWithLogitsLoss(reduction="none")
            return
        kwargs = {"ignore_index": self.ignore_index, "weight": self.weights}
        if self.focal:
            # The focal factor is applied per element, so no reduction here.
            kwargs["reduction"] = "none"
        self.ce = nn.CrossEntropyLoss(**kwargs)

    def _init_weights(self, device: str) -> None:
        """Initialize the weights vector and the cross-entropy function (after the device is known)."""
        if self.exclusive:
            self.weights = torch.tensor(self.weights, device=device, dtype=torch.float)
        else:
            # Each entry gains a leading and a trailing axis of size 1 so it
            # broadcasts against the element-wise loss.
            self.weights = {
                k: torch.tensor(v, device=device, dtype=torch.float)
                .unsqueeze(0)
                .unsqueeze(-1)
                for k, v in self.weights.items()
            }
        self._init_ce()
        self.need_init = False

    def _ce_loss(self, p: torch.Tensor, t: torch.Tensor) -> float:
        """Apply cross-entropy loss.

        Returns the integer 0 when no element is left after removing the
        entries equal to `ignore_index`; otherwise a scalar tensor.
        The target tensor `t` is never modified.
        """
        if self.exclusive:
            t = t.long()
            p = p.transpose(2, 1).contiguous().view(-1, self.num_classes)
            t = t.view(-1)
            mask = t != self.ignore_index
            if torch.sum(mask) == 0:
                return 0
            if not self.focal:
                return self.ce(p, t)
            labels = t[mask]
            pr = F.softmax(p[mask], dim=1)
            # Probability predicted for the true class of every valid element.
            p_true = pr.gather(1, labels.unsqueeze(1)).squeeze(1)
            f = (1 - p_true) ** self.gamma
            return (f * self.ce(p, t)[mask]).mean()

        if self.weights is not None:
            weight0 = self.weights[0]
            weight1 = self.weights[1]
        else:
            weight0 = 1
            weight1 = 1
        # Target value 2 marks hard negatives; they are trained as negatives.
        neg_mask = t == 2
        # clone() instead of copy(): a shallow copy shares storage with `t`,
        # which would overwrite the caller's hard-negative labels.
        target = t.clone()
        target[neg_mask] = 0
        pos_mask = target == 1
        zero_mask = target == 0  # regular and hard negatives
        loss = self.ce(p, target)
        loss[pos_mask] = (loss * weight1)[pos_mask]
        loss[zero_mask] = (loss * weight0)[zero_mask]
        if self.neg_weight > 1:
            loss[neg_mask] = (loss * self.neg_weight)[neg_mask]
        elif self.neg_weight < 1:
            # Down-weighting hard negatives is applied to the regular
            # negatives instead (same behavior as before).
            inv_neg_mask = (~neg_mask) & zero_mask
            loss[inv_neg_mask] = (loss * self.neg_weight)[inv_neg_mask]
        if self.focal:
            pr = torch.sigmoid(p)
            factor = target * ((1 - pr) ** self.gamma) + (1 - target) * (
                pr**self.gamma
            )
            loss = loss * factor
        loss = loss[target != self.ignore_index]
        return loss.mean() if loss.size()[-1] != 0 else 0

    def consistency_loss(self, p: torch.Tensor) -> torch.Tensor:
        """Apply consistency loss.

        Computes the element-wise MSE between the log-probabilities of every
        frame and those of the (detached) previous frame along the last axis,
        clamps it to [0, 16] and averages. Returns zero when there are fewer
        than two frames (the mean of an empty tensor would be NaN).
        """
        if p.shape[-1] < 2:
            return p.new_zeros(())
        mse = self.mse(self.log_nl(p[:, :, 1:]), self.log_nl(p.detach()[:, :, :-1]))
        clamp = torch.clamp(mse, min=0, max=16)
        return torch.mean(clamp)

    def forward(self, predictions: torch.Tensor, target: torch.Tensor) -> float:
        """Compute the loss.

        Parameters
        ----------
        predictions : torch.Tensor
            a tensor of shape (#batch, #classes, #frames); a 4-dimensional tensor is
            iterated over its first axis and one loss term is added per slice
        target : torch.Tensor
            a tensor of shape (#batch, #classes, #frames) or (#batch, #frames)

        Returns
        -------
        loss : float
            the loss value (a scalar tensor in the common case)

        """
        if self.need_init:
            # NOTE(review): a dict input would fail at `.shape` below; the
            # branch is kept unchanged for compatibility.
            if isinstance(predictions, dict):
                device = predictions["device"]
            else:
                device = predictions.device
            self._init_weights(device)
        slices = predictions if len(predictions.shape) == 4 else [predictions]
        loss = 0
        for p in slices:
            loss += self._ce_loss(p, target)
            loss += self.alpha * self.consistency_loss(p)
        # NOTE(review): for 3-dimensional input this divides by the batch
        # size, not by a number of slices -- confirm this is intended.
        return loss / len(predictions)
class MS_TCN_Loss(nn.Module):
    """The MS-TCN loss.

    Cross-entropy (binary cross-entropy when `exclusive` is False) plus a
    consistency loss: the MSE between the log-probabilities of adjacent
    frames, clamped at 16 and weighted by `alpha`.
    """

    def __init__(
        self,
        num_classes: int,
        weights: Iterable = None,
        exclusive: bool = True,
        ignore_index: int = -100,
        focal: bool = False,
        gamma: float = 1,
        alpha: float = 0.15,
        hard_negative_weight: float = 1,
    ) -> None:
        """Initialize the loss.

        Parameters
        ----------
        num_classes : int
            number of classes
        weights : iterable, optional
            class-wise cross-entropy weights; when `exclusive` is False a
            dictionary with keys 0 and 1 (negative / positive weights) is expected
        exclusive : bool, default True
            True if single-label classification is used
        ignore_index : int, default -100
            the elements where target is equal to ignore_index will be ignored by cross-entropy
        focal : bool, default False
            if True, instead of regular cross-entropy the focal loss will be used
        gamma : float, default 1
            the gamma parameter of the focal loss
        alpha : float, default 0.15
            the weight of the consistency loss
        hard_negative_weight : float, default 1
            the weight assigned to the hard negative frames

        """
        super().__init__()
        self.num_classes = int(num_classes)
        self.weights = weights
        self.exclusive = exclusive
        self.ignore_index = ignore_index
        self.focal = focal
        self.gamma = gamma
        self.alpha = alpha
        self.neg_weight = hard_negative_weight
        if exclusive:
            self.log_nl = lambda x: F.log_softmax(x, dim=1)
        else:
            self.log_nl = lambda x: torch.log(torch.sigmoid(x) + 1e-7)
        self.mse = nn.MSELoss(reduction="none")
        # With weights, building the criterion is postponed to the first
        # forward call, when the device of the predictions is known.
        self.need_init = self.weights is not None
        if not self.need_init:
            self._init_ce()

    def _init_ce(self) -> None:
        """Initialize cross-entropy function."""
        if not self.exclusive:
            # Class weights are applied manually in `_ce_loss` for this case.
            self.ce = nn.BCEWithLogitsLoss(reduction="none")
            return
        kwargs = {"ignore_index": self.ignore_index, "weight": self.weights}
        if self.focal:
            # The focal factor is applied per element, so no reduction here.
            kwargs["reduction"] = "none"
        self.ce = nn.CrossEntropyLoss(**kwargs)

    def _init_weights(self, device: str) -> None:
        """Initialize the weights vector and the cross-entropy function (after the device is known)."""
        if self.exclusive:
            self.weights = torch.tensor(self.weights, device=device, dtype=torch.float)
        else:
            # Each entry gains a leading and a trailing axis of size 1 so it
            # broadcasts against the element-wise loss.
            self.weights = {
                k: torch.tensor(v, device=device, dtype=torch.float)
                .unsqueeze(0)
                .unsqueeze(-1)
                for k, v in self.weights.items()
            }
        self._init_ce()
        self.need_init = False

    def _ce_loss(self, p: torch.Tensor, t: torch.Tensor) -> float:
        """Apply cross-entropy loss.

        Returns the integer 0 when no element is left after removing the
        entries equal to `ignore_index`; otherwise a scalar tensor.
        The target tensor `t` is never modified.
        """
        if self.exclusive:
            t = t.long()
            p = p.transpose(2, 1).contiguous().view(-1, self.num_classes)
            t = t.view(-1)
            mask = t != self.ignore_index
            if torch.sum(mask) == 0:
                return 0
            if not self.focal:
                return self.ce(p, t)
            labels = t[mask]
            pr = F.softmax(p[mask], dim=1)
            # Probability predicted for the true class of every valid element.
            p_true = pr.gather(1, labels.unsqueeze(1)).squeeze(1)
            f = (1 - p_true) ** self.gamma
            return (f * self.ce(p, t)[mask]).mean()

        if self.weights is not None:
            weight0 = self.weights[0]
            weight1 = self.weights[1]
        else:
            weight0 = 1
            weight1 = 1
        # Target value 2 marks hard negatives; they are trained as negatives.
        neg_mask = t == 2
        # clone() instead of copy(): a shallow copy shares storage with `t`,
        # which would overwrite the caller's hard-negative labels.
        target = t.clone()
        target[neg_mask] = 0
        pos_mask = target == 1
        zero_mask = target == 0  # regular and hard negatives
        loss = self.ce(p, target)
        loss[pos_mask] = (loss * weight1)[pos_mask]
        loss[zero_mask] = (loss * weight0)[zero_mask]
        if self.neg_weight > 1:
            loss[neg_mask] = (loss * self.neg_weight)[neg_mask]
        elif self.neg_weight < 1:
            # Down-weighting hard negatives is applied to the regular
            # negatives instead (same behavior as before).
            inv_neg_mask = (~neg_mask) & zero_mask
            loss[inv_neg_mask] = (loss * self.neg_weight)[inv_neg_mask]
        if self.focal:
            pr = torch.sigmoid(p)
            factor = target * ((1 - pr) ** self.gamma) + (1 - target) * (
                pr**self.gamma
            )
            loss = loss * factor
        loss = loss[target != self.ignore_index]
        return loss.mean() if loss.size()[-1] != 0 else 0

    def consistency_loss(self, p: torch.Tensor) -> torch.Tensor:
        """Apply consistency loss.

        Computes the element-wise MSE between the log-probabilities of every
        frame and those of the (detached) previous frame along the last axis,
        clamps it to [0, 16] and averages. Returns zero when there are fewer
        than two frames (the mean of an empty tensor would be NaN).
        """
        if p.shape[-1] < 2:
            return p.new_zeros(())
        mse = self.mse(self.log_nl(p[:, :, 1:]), self.log_nl(p.detach()[:, :, :-1]))
        clamp = torch.clamp(mse, min=0, max=16)
        return torch.mean(clamp)

    def forward(self, predictions: torch.Tensor, target: torch.Tensor) -> float:
        """Compute the loss.

        Parameters
        ----------
        predictions : torch.Tensor
            a tensor of shape (#batch, #classes, #frames); a 4-dimensional tensor is
            iterated over its first axis and one loss term is added per slice
        target : torch.Tensor
            a tensor of shape (#batch, #classes, #frames) or (#batch, #frames)

        Returns
        -------
        loss : float
            the loss value (a scalar tensor in the common case)

        """
        if self.need_init:
            # NOTE(review): a dict input would fail at `.shape` below; the
            # branch is kept unchanged for compatibility.
            if isinstance(predictions, dict):
                device = predictions["device"]
            else:
                device = predictions.device
            self._init_weights(device)
        slices = predictions if len(predictions.shape) == 4 else [predictions]
        loss = 0
        for p in slices:
            loss += self._ce_loss(p, target)
            loss += self.alpha * self.consistency_loss(p)
        # NOTE(review): for 3-dimensional input this divides by the batch
        # size, not by a number of slices -- confirm this is intended.
        return loss / len(predictions)
The MS-TCN loss.
Cross-entropy + consistency loss (clamped MSE over the log-probabilities of adjacent frames).
def __init__(
    self,
    num_classes: int,
    weights: Iterable = None,
    exclusive: bool = True,
    ignore_index: int = -100,
    focal: bool = False,
    gamma: float = 1,
    alpha: float = 0.15,
    hard_negative_weight: float = 1,
) -> None:
    """Set up the loss configuration and its criteria.

    Parameters
    ----------
    num_classes : int
        number of classes
    weights : iterable, optional
        class-wise cross-entropy weights
    exclusive : bool, default True
        True if single-label classification is used
    ignore_index : int, default -100
        the elements where target is equal to ignore_index will be ignored by cross-entropy
    focal : bool, default False
        if True, instead of regular cross-entropy the focal loss will be used
    gamma : float, default 1
        the gamma parameter of the focal loss
    alpha : float, default 0.15
        the weight of the consistency loss
    hard_negative_weight : float, default 1
        the weight assigned to the hard negative frames

    """
    super().__init__()
    self.num_classes = int(num_classes)
    self.weights = weights
    self.exclusive = exclusive
    self.ignore_index = ignore_index
    self.focal = focal
    self.gamma = gamma
    self.alpha = alpha
    self.neg_weight = hard_negative_weight
    # Log-probabilities used by the consistency term: log-softmax over the
    # class axis for exclusive labels, log-sigmoid (with an epsilon) otherwise.
    if exclusive:
        self.log_nl = lambda logits: F.log_softmax(logits, dim=1)
    else:
        self.log_nl = lambda logits: torch.log(torch.sigmoid(logits) + 1e-7)
    self.mse = nn.MSELoss(reduction="none")
    # When weights are given, the cross-entropy criterion is created later
    # (once the device is known); otherwise it is created right away.
    self.need_init = self.weights is not None
    if not self.need_init:
        self._init_ce()
Initialize the loss.
Parameters
num_classes : int
    number of classes
weights : iterable, optional
    class-wise cross-entropy weights (for non-exclusive classification, a dictionary with keys 0 and 1 holding the negative and positive class-wise weights)
exclusive : bool, default True
    True if single-label classification is used
ignore_index : int, default -100
    the elements where target is equal to ignore_index will be ignored by cross-entropy
focal : bool, default False
    if True, instead of regular cross-entropy the focal loss will be used
gamma : float, default 1
    the gamma parameter of the focal loss
alpha : float, default 0.15
    the weight of the consistency loss
hard_negative_weight : float, default 1
    the weight assigned to the hard negative frames
def consistency_loss(self, p: torch.Tensor) -> torch.Tensor:
    """Apply consistency loss.

    Computes the element-wise MSE between the log-probabilities
    (`self.log_nl`) of every frame and those of the detached previous frame
    along the last axis, clamps it to [0, 16] and averages.

    Parameters
    ----------
    p : torch.Tensor
        a 3-dimensional tensor of predictions with frames on the last axis

    Returns
    -------
    loss : torch.Tensor
        a scalar tensor; zero when `p` has fewer than two frames

    """
    # With a single frame there are no adjacent pairs; the mean of the empty
    # difference tensor would be NaN and poison the total loss.
    if p.shape[-1] < 2:
        return p.new_zeros(())
    current = self.log_nl(p[:, :, 1:])
    # The previous frame is detached, so gradients flow through `current` only.
    previous = self.log_nl(p.detach()[:, :, :-1])
    truncated = torch.clamp(self.mse(current, previous), min=0, max=16)
    return torch.mean(truncated)
Apply consistency loss.
def forward(self, predictions: torch.Tensor, target: torch.Tensor) -> float:
    """Compute the total loss (cross-entropy plus weighted consistency term).

    Parameters
    ----------
    predictions : torch.Tensor
        a tensor of shape (#batch, #classes, #frames); a 4-dimensional tensor is
        iterated over its first axis and one loss term is added per slice
    target : torch.Tensor
        a tensor of shape (#batch, #classes, #frames) or (#batch, #frames)

    Returns
    -------
    loss : float
        the loss value

    """
    if self.need_init:
        # Weights are moved to the device of the predictions on first use.
        # NOTE(review): a dict input would fail at `.shape` below.
        device = (
            predictions["device"]
            if isinstance(predictions, dict)
            else predictions.device
        )
        self._init_weights(device)
    slices = predictions if len(predictions.shape) == 4 else [predictions]
    total = 0
    for p in slices:
        total += self._ce_loss(p, target)
        total += self.alpha * self.consistency_loss(p)
    # NOTE(review): for 3-dimensional input this divides by the batch size,
    # not by a number of slices -- confirm this is intended.
    return total / len(predictions)
Compute the loss.
Parameters
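predictions : torch.Tensor
    a tensor of shape (#batch, #classes, #frames); a 4-dimensional tensor is also accepted, in which case the loss is accumulated over the slices along its first axis
target : torch.Tensor
    a tensor of shape (#batch, #classes, #frames) or (#batch, #frames)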
Returns
loss : float the loss value