dlc2action.loss.asymmetric_loss
#
# Copyright 2020-present by A. Mathis Group and contributors. All rights reserved.
#
# This project and all its files are licensed under GNU AGPLv3 or later version.
# A copy is included in dlc2action/LICENSE.AGPL.
#
import torch
import torch.nn as nn

"""
In development
"""
torch.manual_seed(0)


class AsymmetricLoss(nn.Module):
    def __init__(
        self,
        gamma_neg=4,
        gamma_pos=1,
        clip=0.05,
        eps=1e-8,
        disable_torch_grad_focal_loss=True,
    ):
        super(AsymmetricLoss, self).__init__()

        self.gamma_neg = gamma_neg
        self.gamma_pos = gamma_pos
        self.clip = clip
        self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss
        self.eps = eps

    def forward(self, x, y):
        """
        Parameters
        ----------
        x: input logits
        y: targets (multi-label binarized vector)
        """

        # Calculating Probabilities
        x_sigmoid = torch.sigmoid(x)
        xs_pos = x_sigmoid
        xs_neg = 1 - x_sigmoid

        # Asymmetric Clipping
        if self.clip is not None and self.clip > 0:
            xs_neg = (xs_neg + self.clip).clamp(max=1)

        # Basic CE calculation
        los_pos = y * torch.log(xs_pos.clamp(min=self.eps))
        los_neg = (1 - y) * torch.log(xs_neg.clamp(min=self.eps))
        loss = los_pos + los_neg

        # Asymmetric Focusing
        if self.gamma_neg > 0 or self.gamma_pos > 0:
            if self.disable_torch_grad_focal_loss:
                # compute the focusing weights without tracking gradients
                torch.set_grad_enabled(False)
            pt0 = xs_pos * y
            pt1 = xs_neg * (1 - y)  # pt = p if t > 0 else 1-p
            pt = pt0 + pt1
            one_sided_gamma = self.gamma_pos * y + self.gamma_neg * (1 - y)
            one_sided_w = torch.pow(1 - pt, one_sided_gamma)
            if self.disable_torch_grad_focal_loss:
                torch.set_grad_enabled(True)
            loss *= one_sided_w

        return -loss.sum()


class AsymmetricLossOptimized(nn.Module):
    """Notice - optimized version, minimizes memory allocation and gpu uploading,
    favors inplace operations"""

    def __init__(
        self,
        gamma_neg=4,
        gamma_pos=1,
        clip=0.05,
        eps=1e-8,
        disable_torch_grad_focal_loss=False,
    ):
        super(AsymmetricLossOptimized, self).__init__()

        self.gamma_neg = gamma_neg
        self.gamma_pos = gamma_pos
        self.clip = clip
        self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss
        self.eps = eps

        # prevent memory allocation and gpu uploading every iteration, and encourages inplace operations
        self.targets = self.anti_targets = self.xs_pos = self.xs_neg = (
            self.asymmetric_w
        ) = self.loss = None

    def forward(self, x, y):
        """
        Parameters
        ----------
        x: input logits
        y: targets (multi-label binarized vector)
        """

        self.targets = y
        self.anti_targets = 1 - y

        # Calculating Probabilities
        self.xs_pos = torch.sigmoid(x)
        self.xs_neg = 1.0 - self.xs_pos

        # Asymmetric Clipping
        if self.clip is not None and self.clip > 0:
            self.xs_neg.add_(self.clip).clamp_(max=1)

        # Basic CE calculation
        self.loss = self.targets * torch.log(self.xs_pos.clamp(min=self.eps))
        self.loss.add_(self.anti_targets * torch.log(self.xs_neg.clamp(min=self.eps)))

        # Asymmetric Focusing
        if self.gamma_neg > 0 or self.gamma_pos > 0:
            if self.disable_torch_grad_focal_loss:
                # compute the focusing weights without tracking gradients
                torch.set_grad_enabled(False)
            self.xs_pos = self.xs_pos * self.targets
            self.xs_neg = self.xs_neg * self.anti_targets
            self.asymmetric_w = torch.pow(
                1 - self.xs_pos - self.xs_neg,
                self.gamma_pos * self.targets + self.gamma_neg * self.anti_targets,
            )
            if self.disable_torch_grad_focal_loss:
                torch.set_grad_enabled(True)
            self.loss *= self.asymmetric_w

        return -self.loss.sum()


class ASLSingleLabel(nn.Module):
    def __init__(self, gamma_pos=0, gamma_neg=4, eps: float = 0.1, reduction="mean"):
        super(ASLSingleLabel, self).__init__()

        self.eps = eps
        self.logsoftmax = nn.LogSoftmax(dim=-1)
        self.targets_classes = []  # prevent gpu repeated memory allocation
        self.gamma_pos = gamma_pos
        self.gamma_neg = gamma_neg
        self.reduction = reduction

    def forward(self, inputs, target, reduction=None):
        num_classes = inputs.size()[-1]
        log_preds = self.logsoftmax(inputs)
        self.targets_classes = torch.zeros_like(inputs).scatter_(
            1, target.long().unsqueeze(1), 1
        )

        # ASL weights
        targets = self.targets_classes
        anti_targets = 1 - targets
        xs_pos = torch.exp(log_preds)
        xs_neg = 1 - xs_pos
        xs_pos = xs_pos * targets
        xs_neg = xs_neg * anti_targets
        asymmetric_w = torch.pow(
            1 - xs_pos - xs_neg,
            self.gamma_pos * targets + self.gamma_neg * anti_targets,
        )
        log_preds = log_preds * asymmetric_w

        if self.eps > 0:  # label smoothing
            self.targets_classes.mul_(1 - self.eps).add_(self.eps / num_classes)

        # loss calculation
        loss = -self.targets_classes.mul(log_preds)

        loss = loss.sum(dim=-1)
        if self.reduction == "mean":
            loss = loss.mean()

        return loss
class AsymmetricLoss(nn.Module)
Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::

    import torch.nn as nn
    import torch.nn.functional as F

    class Model(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv1 = nn.Conv2d(1, 20, 5)
            self.conv2 = nn.Conv2d(20, 20, 5)

        def forward(self, x):
            x = F.relu(self.conv1(x))
            return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

:ivar training: Boolean represents whether this module is in training or evaluation mode.
:vartype training: bool
def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8, disable_torch_grad_focal_loss=True)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
def forward(self, x, y)
"
Parameters
x: input logits y: targets (multi-label binarized vector)
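For orientation, a minimal usage sketch (the batch size, number of labels, and random inputs are illustrative only, not taken from dlc2action). The loss expects raw logits and a 0/1 multi-label target tensor of the same shape, and returns the negative sum of the asymmetrically weighted log-likelihoods::

    import torch

    criterion = AsymmetricLoss(gamma_neg=4, gamma_pos=1, clip=0.05)
    logits = torch.randn(8, 5, requires_grad=True)   # 8 samples, 5 labels, raw scores
    targets = torch.randint(0, 2, (8, 5)).float()    # 0/1 multi-label targets
    loss = criterion(logits, targets)                # scalar: -sum of weighted log-likelihoods
    loss.backward()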
class AsymmetricLossOptimized(nn.Module)
Note: this is an optimized version that minimizes memory allocation and GPU uploading and favors in-place operations.
def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8, disable_torch_grad_focal_loss=False)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
def forward(self, x, y)
"
Parameters
x: input logits y: targets (multi-label binarized vector)
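As a sketch of how the two variants relate (shapes and values below are illustrative), AsymmetricLossOptimized is intended as a drop-in replacement for AsymmetricLoss that reuses internal buffers and in-place operations, so for the same inputs and hyperparameters the two should agree numerically; their defaults differ only in disable_torch_grad_focal_loss, which affects gradient tracking of the focusing weights rather than the returned value::

    import torch

    logits = torch.randn(8, 5)                     # illustrative batch of raw scores
    targets = torch.randint(0, 2, (8, 5)).float()  # 0/1 multi-label targets

    loss_ref = AsymmetricLoss()(logits, targets)
    loss_opt = AsymmetricLossOptimized()(logits, targets)
    print(torch.allclose(loss_ref, loss_opt))      # expected: True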
class ASLSingleLabel(nn.Module)
def __init__(self, gamma_pos=0, gamma_neg=4, eps: float = 0.1, reduction="mean")
Initialize internal Module state, shared by both nn.Module and ScriptModule.
def forward(self, inputs, target, reduction=None)
Define the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
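A minimal usage sketch for the single-label variant (sizes are illustrative): inputs holds raw class logits and target holds integer class indices; here eps acts as a label-smoothing factor rather than a numerical clamp, and with reduction="mean" (the default) the result is a scalar averaged over the batch::

    import torch

    criterion = ASLSingleLabel(gamma_pos=0, gamma_neg=4, eps=0.1)
    inputs = torch.randn(8, 10, requires_grad=True)  # 8 samples, 10 classes
    target = torch.randint(0, 10, (8,))              # integer class indices
    loss = criterion(inputs, target)                 # scalar (reduction="mean")
    loss.backward()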