dlc2action.loss.mse
The mean squared error loss.
#
# Copyright 2020-present by A. Mathis Group and contributors. All rights reserved.
#
# This project and all its files are licensed under GNU AGPLv3 or later version.
# A copy is included in dlc2action/LICENSE.AGPL.
#
"""The mean squared error loss."""

import torch
from torch import nn


class MSE(nn.Module):
    """Mean squared error that skips elements whose target equals ``ignore_index``."""

    def __init__(self, ignore_index=-100):
        """Initialize the loss.

        Parameters
        ----------
        ignore_index : int, default -100
            the elements where target is equal to ignore_index will be ignored

        """
        super().__init__()
        self.ignore_index = ignore_index

    def forward(self, predicted: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute the loss.

        Parameters
        ----------
        predicted, target : torch.Tensor
            tensors of the same shape; elements where ``target == self.ignore_index``
            are excluded from the loss

        Returns
        -------
        loss : torch.Tensor
            scalar tensor holding the mean squared difference over the kept
            elements; 0 (still attached to the autograd graph) when every
            element is ignored

        """
        mask = target != self.ignore_index
        squared_error = (predicted[mask] - target[mask]) ** 2
        # Divide by the element count clamped to 1 instead of calling
        # ``torch.mean``: the mean of an empty tensor is NaN, so a batch whose
        # targets are all ignored would otherwise poison the gradients.
        return squared_error.sum() / mask.sum().clamp(min=1)
class
MSE(torch.nn.modules.module.Module):
class MSE(nn.Module):
    """Mean squared error that skips elements whose target equals ``ignore_index``."""

    def __init__(self, ignore_index=-100):
        """Initialize the loss.

        Parameters
        ----------
        ignore_index : int, default -100
            the elements where target is equal to ignore_index will be ignored

        """
        super().__init__()
        self.ignore_index = ignore_index

    def forward(self, predicted: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute the loss.

        Parameters
        ----------
        predicted, target : torch.Tensor
            tensors of the same shape; elements where ``target == self.ignore_index``
            are excluded from the loss

        Returns
        -------
        loss : torch.Tensor
            scalar tensor holding the mean squared difference over the kept
            elements; 0 (still attached to the autograd graph) when every
            element is ignored

        """
        mask = target != self.ignore_index
        squared_error = (predicted[mask] - target[mask]) ** 2
        # Divide by the element count clamped to 1 instead of calling
        # ``torch.mean``: the mean of an empty tensor is NaN, so a batch whose
        # targets are all ignored would otherwise poison the gradients.
        return squared_error.sum() / mask.sum().clamp(min=1)
Mean squared error loss that skips elements whose target equals ignore_index.
MSE(ignore_index=-100)
def __init__(self, ignore_index=-100):
    """Set up the masked MSE module.

    Parameters
    ----------
    ignore_index : int, default -100
        target value marking elements that are left out of the loss

    """
    # Run the parent constructor first so nn.Module bookkeeping exists
    # before any attribute is assigned on the instance.
    super().__init__()
    self.ignore_index = ignore_index
Initialize the loss.
Parameters
ignore_index : int the elements where target is equal to ignore_index will be ignored
def
forward(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
def forward(self, predicted: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Compute the loss.

    Parameters
    ----------
    predicted, target : torch.Tensor
        tensors of the same shape; elements where ``target == self.ignore_index``
        are excluded from the loss

    Returns
    -------
    loss : torch.Tensor
        scalar tensor holding the mean squared difference over the kept
        elements; 0 (still attached to the autograd graph) when every
        element is ignored

    """
    mask = target != self.ignore_index
    squared_error = (predicted[mask] - target[mask]) ** 2
    # Divide by the element count clamped to 1 instead of calling
    # ``torch.mean``: the mean of an empty tensor is NaN, so a batch whose
    # targets are all ignored would otherwise poison the gradients.
    return squared_error.sum() / mask.sum().clamp(min=1)
Compute the loss.
Parameters
predicted, target : torch.Tensor tensors of the same (arbitrary) shape; elements where target equals ignore_index are excluded
Returns
loss : torch.Tensor a scalar (0-dimensional) tensor holding the loss value