dlc2action.loss.mse

The mean squared error loss.

 1#
 2# Copyright 2020-present by A. Mathis Group and contributors. All rights reserved.
 3#
 4# This project and all its files are licensed under GNU AGPLv3 or later version. 
 5# A copy is included in dlc2action/LICENSE.AGPL.
 6#
 7"""The mean squared error loss."""
 8
 9import torch
10from torch import nn
11
12
class MSE(nn.Module):
    """Mean squared error that skips elements marked with ``ignore_index``."""

    def __init__(self, ignore_index=-100):
        """Initialize the loss.

        Parameters
        ----------
        ignore_index : int
            the elements where target is equal to ignore_index will be ignored

        """
        super().__init__()
        self.ignore_index = ignore_index

    def forward(self, predicted: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute the loss.

        Parameters
        ----------
        predicted, target : torch.Tensor
            a tensor of any shape

        Returns
        -------
        loss : torch.Tensor
            the mean of the squared differences over the elements where
            target is not equal to ignore_index; zero when there are no
            such elements

        """
        # Boolean mask selecting the elements that take part in the loss.
        mask = target != self.ignore_index
        squared_error = (predicted[mask] - target[mask]) ** 2
        if squared_error.numel() == 0:
            # torch.mean of an empty tensor is NaN. The empty sum is 0 and is
            # still computed from `predicted`, so backward() keeps working.
            return squared_error.sum()
        return squared_error.mean()
class MSE(torch.nn.modules.module.Module):
class MSE(nn.Module):
    """Mean squared error that skips elements marked with ``ignore_index``."""

    def __init__(self, ignore_index=-100):
        """Initialize the loss.

        Parameters
        ----------
        ignore_index : int
            the elements where target is equal to ignore_index will be ignored

        """
        super().__init__()
        self.ignore_index = ignore_index

    def forward(self, predicted: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute the loss.

        Parameters
        ----------
        predicted, target : torch.Tensor
            a tensor of any shape

        Returns
        -------
        loss : torch.Tensor
            the mean of the squared differences over the elements where
            target is not equal to ignore_index; zero when there are no
            such elements

        """
        # Boolean mask selecting the elements that take part in the loss.
        mask = target != self.ignore_index
        squared_error = (predicted[mask] - target[mask]) ** 2
        if squared_error.numel() == 0:
            # torch.mean of an empty tensor is NaN. The empty sum is 0 and is
            # still computed from `predicted`, so backward() keeps working.
            return squared_error.sum()
        return squared_error.mean()

Mean square error with ignore_index parameter.

MSE(ignore_index=-100)
17    def __init__(self, ignore_index=-100):
18        """Initialize the loss.
19
20        Parameters
21        ----------
22        ignore_index : int
23            the elements where target is equal to ignore_index will be ignored
24
25        """
26        super().__init__()
27        self.ignore_index = ignore_index

Initialize the loss.

Parameters

ignore_index : int — elements whose target value equals ignore_index are excluded from the loss

ignore_index
def forward(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
29    def forward(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
30        """Compute the loss.
31
32        Parameters
33        ----------
34        predicted, target : torch.Tensor
35            a tensor of any shape
36
37        Returns
38        -------
39        loss : float
40            the loss value
41
42        """
43        mask = target != self.ignore_index
44        return torch.mean((predicted[mask] - target[mask]) ** 2)

Compute the loss.

Parameters

predicted, target : torch.Tensor a tensor of any shape

Returns

loss : torch.Tensor — the loss value (the tensor returned by torch.mean, not a Python float)