dlc2action.loss.asymmetric_loss

#
# Copyright 2020-present by A. Mathis Group and contributors. All rights reserved.
#
# This project and all its files are licensed under GNU AGPLv3 or later version.
# A copy is included in dlc2action/LICENSE.AGPL.
#
import torch
import torch.nn as nn

"""
In development
"""
torch.manual_seed(0)


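# Note (descriptive summary derived from the code below): AsymmetricLoss and
# AsymmetricLossOptimized implement an asymmetric, focal-style binary cross-entropy
# for multi-label targets. For a single label with probability p = sigmoid(logit)
# and target y in {0, 1}, the per-element loss is
#
#     pt    = p                      if y == 1
#     pt    = min(1 - p + clip, 1)   if y == 0      (asymmetric probability shifting)
#     gamma = gamma_pos              if y == 1 else gamma_neg
#     loss  = -(1 - pt) ** gamma * log(max(pt, eps))
#
# With the default gamma_neg > gamma_pos, easy negatives are down-weighted much more
# aggressively than positives. ASLSingleLabel applies the same focusing weights to
# softmax probabilities (single-label targets) and adds label smoothing.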
class AsymmetricLoss(nn.Module):
    def __init__(
        self,
        gamma_neg=4,
        gamma_pos=1,
        clip=0.05,
        eps=1e-8,
        disable_torch_grad_focal_loss=True,
    ):
        super(AsymmetricLoss, self).__init__()

        self.gamma_neg = gamma_neg
        self.gamma_pos = gamma_pos
        self.clip = clip
        self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss
        self.eps = eps

    def forward(self, x, y):
        """
        Parameters
        ----------
        x: input logits
        y: targets (multi-label binarized vector)
        """

        # Calculating Probabilities
        x_sigmoid = torch.sigmoid(x)
        xs_pos = x_sigmoid
        xs_neg = 1 - x_sigmoid

        # Asymmetric Clipping
        if self.clip is not None and self.clip > 0:
            xs_neg = (xs_neg + self.clip).clamp(max=1)

        # Basic CE calculation
        los_pos = y * torch.log(xs_pos.clamp(min=self.eps))
        los_neg = (1 - y) * torch.log(xs_neg.clamp(min=self.eps))
        loss = los_pos + los_neg

        # Asymmetric Focusing
        if self.gamma_neg > 0 or self.gamma_pos > 0:
            if self.disable_torch_grad_focal_loss:
                # compute the focusing weights without tracking gradients
                torch.set_grad_enabled(False)
            pt0 = xs_pos * y
            pt1 = xs_neg * (1 - y)  # pt = p if t > 0 else 1-p
            pt = pt0 + pt1
            one_sided_gamma = self.gamma_pos * y + self.gamma_neg * (1 - y)
            one_sided_w = torch.pow(1 - pt, one_sided_gamma)
            if self.disable_torch_grad_focal_loss:
                torch.set_grad_enabled(True)
            loss *= one_sided_w

        return -loss.sum()

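# Example usage for AsymmetricLoss (an illustrative sketch; tensor shapes are arbitrary):
#
#   >>> criterion = AsymmetricLoss(gamma_neg=4, gamma_pos=1, clip=0.05)
#   >>> logits = torch.randn(8, 5)                      # raw scores; sigmoid is applied inside
#   >>> targets = torch.randint(0, 2, (8, 5)).float()   # multi-label binarized targets
#   >>> loss = criterion(logits, targets)               # scalar, summed over all elements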

class AsymmetricLossOptimized(nn.Module):
    """Optimized version: minimizes memory allocation and GPU uploading and
    favors in-place operations"""

    def __init__(
        self,
        gamma_neg=4,
        gamma_pos=1,
        clip=0.05,
        eps=1e-8,
        disable_torch_grad_focal_loss=False,
    ):
        super(AsymmetricLossOptimized, self).__init__()

        self.gamma_neg = gamma_neg
        self.gamma_pos = gamma_pos
        self.clip = clip
        self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss
        self.eps = eps

        # prevent memory allocation and GPU uploading every iteration and encourage in-place operations
        self.targets = self.anti_targets = self.xs_pos = self.xs_neg = None
        self.asymmetric_w = self.loss = None

    def forward(self, x, y):
        """
        Parameters
        ----------
        x: input logits
        y: targets (multi-label binarized vector)
        """

        self.targets = y
        self.anti_targets = 1 - y

        # Calculating Probabilities
        self.xs_pos = torch.sigmoid(x)
        self.xs_neg = 1.0 - self.xs_pos

        # Asymmetric Clipping
        if self.clip is not None and self.clip > 0:
            self.xs_neg.add_(self.clip).clamp_(max=1)

        # Basic CE calculation
        self.loss = self.targets * torch.log(self.xs_pos.clamp(min=self.eps))
        self.loss.add_(self.anti_targets * torch.log(self.xs_neg.clamp(min=self.eps)))

        # Asymmetric Focusing
        if self.gamma_neg > 0 or self.gamma_pos > 0:
            if self.disable_torch_grad_focal_loss:
                # compute the focusing weights without tracking gradients
                torch.set_grad_enabled(False)
            self.xs_pos = self.xs_pos * self.targets
            self.xs_neg = self.xs_neg * self.anti_targets
            self.asymmetric_w = torch.pow(
                1 - self.xs_pos - self.xs_neg,
                self.gamma_pos * self.targets + self.gamma_neg * self.anti_targets,
            )
            if self.disable_torch_grad_focal_loss:
                torch.set_grad_enabled(True)
            self.loss *= self.asymmetric_w

        return -self.loss.sum()

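# AsymmetricLossOptimized is intended as a drop-in replacement for AsymmetricLoss that
# caches intermediates as attributes and uses in-place ops; it should produce the same
# loss value. Illustrative sketch (shapes are arbitrary):
#
#   >>> criterion = AsymmetricLossOptimized(gamma_neg=4, gamma_pos=1, clip=0.05)
#   >>> loss = criterion(torch.randn(8, 5), torch.randint(0, 2, (8, 5)).float())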

class ASLSingleLabel(nn.Module):
    def __init__(self, gamma_pos=0, gamma_neg=4, eps: float = 0.1, reduction="mean"):
        super(ASLSingleLabel, self).__init__()

        self.eps = eps
        self.logsoftmax = nn.LogSoftmax(dim=-1)
        self.targets_classes = []  # prevent repeated GPU memory allocation
        self.gamma_pos = gamma_pos
        self.gamma_neg = gamma_neg
        self.reduction = reduction

    def forward(self, inputs, target, reduction=None):
        num_classes = inputs.size()[-1]
        log_preds = self.logsoftmax(inputs)
        self.targets_classes = torch.zeros_like(inputs).scatter_(
            1, target.long().unsqueeze(1), 1
        )

        # ASL weights
        targets = self.targets_classes
        anti_targets = 1 - targets
        xs_pos = torch.exp(log_preds)
        xs_neg = 1 - xs_pos
        xs_pos = xs_pos * targets
        xs_neg = xs_neg * anti_targets
        asymmetric_w = torch.pow(
            1 - xs_pos - xs_neg,
            self.gamma_pos * targets + self.gamma_neg * anti_targets,
        )
        log_preds = log_preds * asymmetric_w

        if self.eps > 0:  # label smoothing
            self.targets_classes.mul_(1 - self.eps).add_(self.eps / num_classes)

        # loss calculation
        loss = -self.targets_classes.mul(log_preds)

        loss = loss.sum(dim=-1)
        if self.reduction == "mean":
            loss = loss.mean()

        return loss