dlc2action.loss.mse
The mean squared error loss
#
# Copyright 2020-2022 by A. Mathis Group and contributors. All rights reserved.
#
# This project and all its files are licensed under GNU AGPLv3 or later version. A copy is included in dlc2action/LICENSE.AGPL.
#
"""
The mean squared error loss
"""

from torch import nn
import torch


class _MSE(nn.Module):
    """
    Mean square error with ignore_index parameter

    Elements whose target equals ``ignore_index`` are excluded from both the
    sum of squared errors and the element count used for averaging.
    """

    def __init__(self, ignore_index: int = -100):
        """
        Parameters
        ----------
        ignore_index : int, default -100
            the elements where target is equal to ignore_index will be ignored
        """

        super().__init__()
        self.ignore_index = ignore_index

    def forward(self, predicted: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Compute the loss

        Parameters
        ----------
        predicted, target : torch.Tensor
            tensors of the same shape; the boolean mask built from ``target``
            is used to index both of them

        Returns
        -------
        loss : torch.Tensor
            a scalar (0-d) tensor holding the mean squared error over the
            non-ignored elements; 0 if every target equals ``ignore_index``
        """

        mask = target != self.ignore_index
        squared_error = (predicted[mask] - target[mask]) ** 2
        # sum / count equals the mean whenever at least one element is kept.
        # Clamping the count to 1 makes the all-ignored case return 0 instead
        # of NaN (the mean of an empty tensor). The result stays attached to
        # the autograd graph, so backward() still works (with zero gradients).
        n_valid = mask.sum().clamp(min=1)
        return squared_error.sum() / n_valid