dlc2action.metric.base_metric
Abstract parent class for all metrics.
#
# Copyright 2020-present by A. Mathis Group and contributors. All rights reserved.
#
# This project and all its files are licensed under GNU AGPLv3 or later version.
# A copy is included in dlc2action/LICENSE.AGPL.
#
"""Abstract parent class for all metrics."""

from abc import ABC, abstractmethod
from typing import Any, Dict, Union

import torch

torch.manual_seed(0)


class Metric(ABC):
    """Base class for all metrics.

    Metrics are reset at the beginning of each epoch, updated with batch data and then calculated at the end of the epoch.
    If needs_raw_data is True for a metric class, it should expect to receive raw model output as the predicted vector;
    otherwise it should be the final class prediction
    """

    needs_raw_data = False
    """
    If `True`, `dlc2action.task.universal_task.Task` will pass raw data to the metric (only primary predict
    function applied).
    Otherwise it will pass a prediction for the classes.
    """

    def __init__(self) -> None:
        """Initialize the class."""
        self.reset()

    @abstractmethod
    def update(
        self,
        predicted: torch.Tensor,
        target: torch.Tensor,
        tags: torch.Tensor,
    ) -> None:
        """Update the intrinsic parameters (with a batch).

        Parameters
        ----------
        predicted : torch.Tensor
            the main prediction tensor generated by the model
        target : torch.Tensor
            the corresponding main target tensor
        tags : torch.Tensor
            the tensor of meta tags (or `None`, if tags are not given)

        """

    @abstractmethod
    def reset(self) -> None:
        """Reset the intrinsic parameters (at the beginning of an epoch)."""

    @abstractmethod
    def calculate(self) -> Union[float, Dict]:
        """Calculate the metric (at the end of an epoch).

        Returns
        -------
        result : float | dict
            either the single value of the metric or a dictionary where the keys are class indices and the values
            are class metric values

        """
class Metric(ABC):
Base class for all metrics.
Metrics are reset at the beginning of each epoch, updated with each batch of data, and calculated at the end of the epoch. If needs_raw_data is True for a metric class, the metric should expect to receive the raw model output as the predicted tensor; otherwise it receives the final class prediction.
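The reset / update / calculate cycle described above can be sketched as follows. This is a minimal illustration, not code from dlc2action: the `Metric` class here is a stand-in that mirrors the documented interface so the snippet runs without the library or torch installed, `Accuracy` is a hypothetical subclass, and plain lists stand in for the prediction and target tensors.

```python
from abc import ABC, abstractmethod


class Metric(ABC):
    """Stand-in mirroring the abstract interface documented on this page."""

    needs_raw_data = False

    def __init__(self) -> None:
        self.reset()

    @abstractmethod
    def update(self, predicted, target, tags) -> None: ...

    @abstractmethod
    def reset(self) -> None: ...

    @abstractmethod
    def calculate(self): ...


class Accuracy(Metric):
    """Hypothetical metric: fraction of predictions equal to the target."""

    def reset(self) -> None:
        # Clear the counters accumulated during the previous epoch.
        self.correct = 0
        self.total = 0

    def update(self, predicted, target, tags) -> None:
        # Accumulate counts for one batch; tags are ignored in this sketch.
        for p, t in zip(predicted, target):
            self.correct += int(p == t)
            self.total += 1

    def calculate(self) -> float:
        # Return 0.0 instead of dividing by zero when no batch was seen.
        return self.correct / self.total if self.total else 0.0


metric = Accuracy()
epochs = [
    [([0, 1, 1], [0, 1, 0]), ([1, 0], [1, 0])],  # epoch 0: two batches
    [([1, 1], [0, 1])],  # epoch 1: one batch
]
results = []
for batches in epochs:
    metric.reset()  # beginning of the epoch
    for predicted, target in batches:
        metric.update(predicted, target, None)  # once per batch
    results.append(metric.calculate())  # end of the epoch
```

Keeping only running counters in `reset` and `update` means the metric never has to hold a whole epoch of predictions in memory.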
needs_raw_data = False
If True, dlc2action.task.universal_task.Task will pass raw data to the metric (with only the primary predict function applied). Otherwise it will pass a prediction for the classes.
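To illustrate how the needs_raw_data flag might be consumed, here is a hedged sketch. Everything except the flag name and the `update` / `reset` / `calculate` interface is an assumption: `feed_metric` is a hypothetical helper and not the actual logic of `dlc2action.task.universal_task.Task`, the arg-max step is only one plausible way to obtain a class prediction, and both metric subclasses are invented for the example. Plain lists stand in for tensors so the snippet runs without torch.

```python
from abc import ABC, abstractmethod


class Metric(ABC):
    """Stand-in mirroring the abstract interface documented on this page."""

    needs_raw_data = False

    def __init__(self) -> None:
        self.reset()

    @abstractmethod
    def update(self, predicted, target, tags) -> None: ...

    @abstractmethod
    def reset(self) -> None: ...

    @abstractmethod
    def calculate(self): ...


class MeanConfidence(Metric):
    """Hypothetical metric that needs raw scores: mean of the top score."""

    needs_raw_data = True

    def reset(self) -> None:
        self.total = 0.0
        self.count = 0

    def update(self, predicted, target, tags) -> None:
        for scores in predicted:  # one list of per-class scores per sample
            self.total += max(scores)
            self.count += 1

    def calculate(self) -> float:
        return self.total / self.count if self.count else 0.0


class PredictedPositives(Metric):
    """Hypothetical metric on final predictions: how often class 1 is chosen."""

    def reset(self) -> None:
        self.count = 0

    def update(self, predicted, target, tags) -> None:
        self.count += sum(1 for p in predicted if p == 1)

    def calculate(self) -> float:
        return float(self.count)


def feed_metric(metric, raw_scores, target, tags=None):
    """Hypothetical dispatcher: choose what to pass based on the flag."""
    if metric.needs_raw_data:
        predicted = raw_scores
    else:
        # Assumed conversion: index of the highest score per sample.
        predicted = [max(range(len(s)), key=s.__getitem__) for s in raw_scores]
    metric.update(predicted, target, tags)


raw = [[0.5, 0.25], [0.25, 0.75]]
target = [0, 1]
confidence, positives = MeanConfidence(), PredictedPositives()
feed_metric(confidence, raw, target)
feed_metric(positives, raw, target)
```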
@abstractmethod
def update(self, predicted: torch.Tensor, target: torch.Tensor, tags: torch.Tensor) -> None:
Update the intrinsic parameters (with a batch).
Parameters
predicted : torch.Tensor
the main prediction tensor generated by the model
target : torch.Tensor
the corresponding main target tensor
tags : torch.Tensor
the tensor of meta tags (or None, if tags are not given)
@abstractmethod
def reset(self) -> None:
Reset the intrinsic parameters (at the beginning of an epoch).
@abstractmethod
def calculate(self) -> Union[float, Dict]:
Calculate the metric (at the end of an epoch).
Returns
result : float | dict
either the single value of the metric or a dictionary where the keys are class indices and the values are class metric values
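As an illustration of the dictionary form of the return value, the sketch below computes recall per class and returns a mapping from class index to value. It is an assumption-laden example rather than a metric shipped with dlc2action: `PerClassRecall` and its `num_classes` argument are hypothetical, the `Metric` class is a stand-in for the documented interface, and plain lists replace tensors so the code runs without torch.

```python
from abc import ABC, abstractmethod


class Metric(ABC):
    """Stand-in mirroring the abstract interface documented on this page."""

    needs_raw_data = False

    def __init__(self) -> None:
        self.reset()

    @abstractmethod
    def update(self, predicted, target, tags) -> None: ...

    @abstractmethod
    def reset(self) -> None: ...

    @abstractmethod
    def calculate(self): ...


class PerClassRecall(Metric):
    """Hypothetical metric returning {class index: recall}."""

    def __init__(self, num_classes: int) -> None:
        # Set attributes used by reset() before the base constructor calls it.
        self.num_classes = num_classes
        super().__init__()

    def reset(self) -> None:
        self.true_positives = [0] * self.num_classes
        self.support = [0] * self.num_classes

    def update(self, predicted, target, tags) -> None:
        for p, t in zip(predicted, target):
            self.support[t] += 1
            if p == t:
                self.true_positives[t] += 1

    def calculate(self) -> dict:
        # Classes that never appeared in the targets are left out.
        return {
            c: self.true_positives[c] / self.support[c]
            for c in range(self.num_classes)
            if self.support[c] > 0
        }


recall = PerClassRecall(num_classes=3)
recall.update([0, 1, 1, 0], [0, 1, 0, 1], None)
per_class = recall.calculate()
```

Because the base constructor calls `reset()`, a subclass that takes its own arguments has to store them before calling `super().__init__()`, as done here.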