dlc2action.loss.mse

The mean squared error loss

 1#
 2# Copyright 2020-2022 by A. Mathis Group and contributors. All rights reserved.
 3#
 4# This project and all its files are licensed under GNU AGPLv3 or later version. A copy is included in dlc2action/LICENSE.AGPL.
 5#
 6"""
 7The mean squared error loss
 8"""
 9
10from torch import nn
11import torch
12
13
14class _MSE(nn.Module):
15    """
16    Mean square error with ignore_index parameter
17    """
18
19    def __init__(self, ignore_index=-100):
20        """
21        Parameters
22        ----------
23        ignore_index : int
24            the elements where target is equal to ignore_index will be ignored
25        """
26
27        super().__init__()
28        self.ignore_index = ignore_index
29
30    def forward(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
31        """
32        Compute the loss
33
34        Parameters
35        ----------
36        predicted, target : torch.Tensor
37            a tensor of any shape
38
39        Returns
40        -------
41        loss : float
42            the loss value
43        """
44
45        mask = target != self.ignore_index
46        return torch.mean((predicted[mask] - target[mask]) ** 2)