dlc2action.metric.base_metric
Abstract parent class for all metrics
#
# Copyright 2020-2022 by A. Mathis Group and contributors. All rights reserved.
#
# This project and all its files are licensed under GNU AGPLv3 or later version. A copy is included in dlc2action/LICENSE.AGPL.
#
"""
Abstract parent class for all metrics
"""

import torch
from typing import Union, Dict, Any
from abc import ABC, abstractmethod


class Metric(ABC):
    """
    Base class for all metrics

    Metrics are reset at the beginning of each epoch, updated with batch data and then calculated at the end of the epoch.
    If needs_raw_data is True for a metric class, it should expect to receive raw model output as the predicted vector;
    otherwise it should be the final class prediction
    """

    needs_raw_data = False
    """
    If `True`, `dlc2action.task.universal_task.Task` will pass raw data to the metric (only primary predict
    function applied).
    Otherwise it will pass a prediction for the classes.
    """

    def __init__(self) -> None:
        # Start every metric from a clean state
        self.reset()

    @abstractmethod
    def update(
        self,
        predicted: torch.Tensor,
        target: torch.Tensor,
        tags: torch.Tensor,
    ) -> None:
        """
        Update the intrinsic parameters (with a batch)

        Parameters
        ----------
        predicted : torch.Tensor
            the main prediction tensor generated by the model
        target : torch.Tensor
            the corresponding main target tensor
        tags : torch.Tensor
            the tensor of meta tags (or `None`, if tags are not given)
        """

    @abstractmethod
    def reset(self) -> None:
        """
        Reset the intrinsic parameters (at the beginning of an epoch)
        """

    @abstractmethod
    def calculate(self) -> Union[float, Dict]:
        """
        Calculate the metric (at the end of an epoch)

        Returns
        -------
        result : float | dict
            either the single value of the metric or a dictionary where the keys are class indices and the values
            are class metric values
        """
class Metric(ABC):
Base class for all metrics.
Metrics are reset at the beginning of each epoch, updated with batch data and then calculated at the end of the epoch. If `needs_raw_data` is `True` for a metric class, the metric should expect raw model output as the `predicted` tensor; otherwise `predicted` is the final class prediction.
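The reset, update, calculate cycle described above can be sketched as follows. This is a minimal illustration, not code from dlc2action: the `Metric` class here is a local stand-in that mirrors the interface on this page so the snippet runs on its own, and `ToyAccuracy`, `batches` and `epoch_values` are hypothetical names.

```python
from abc import ABC, abstractmethod
from typing import Dict, Union

import torch


class Metric(ABC):
    """Local stand-in mirroring the interface documented on this page."""

    needs_raw_data = False

    def __init__(self) -> None:
        self.reset()

    @abstractmethod
    def update(self, predicted: torch.Tensor, target: torch.Tensor, tags: torch.Tensor) -> None: ...

    @abstractmethod
    def reset(self) -> None: ...

    @abstractmethod
    def calculate(self) -> Union[float, Dict]: ...


class ToyAccuracy(Metric):
    """Hypothetical metric: fraction of predictions equal to the target."""

    def reset(self) -> None:
        # Counters accumulated over one epoch
        self.correct = 0
        self.total = 0

    def update(self, predicted, target, tags=None) -> None:
        # predicted holds final class labels because needs_raw_data is False
        self.correct += int((predicted == target).sum().item())
        self.total += target.numel()

    def calculate(self) -> float:
        return self.correct / max(self.total, 1)


metric = ToyAccuracy()
batches = [
    (torch.tensor([0, 1, 1, 2]), torch.tensor([0, 1, 2, 2])),
    (torch.tensor([2, 2, 0, 0]), torch.tensor([2, 1, 0, 0])),
]
epoch_values = []
for epoch in range(2):
    metric.reset()  # beginning of the epoch
    for predicted, target in batches:
        metric.update(predicted, target, tags=None)  # once per batch
    epoch_values.append(metric.calculate())  # end of the epoch
```

Because `reset()` runs at the start of every epoch, the second epoch does not inherit counts from the first.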
needs_raw_data = False
If `True`, `dlc2action.task.universal_task.Task` will pass raw data to the metric (only the primary predict function applied). Otherwise it will pass a prediction for the classes.
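As a rough sketch of a metric that opts into raw data, the example below sets `needs_raw_data = True` and converts the scores itself. The tensor layout is an assumption made for illustration (unnormalized scores of shape `(N, C)` and class indices of shape `(N,)`); the shapes that `Task` actually passes are not stated on this page. `Metric` is again a local stand-in and `ToyTrueClassProbability` is a hypothetical name.

```python
from abc import ABC, abstractmethod

import torch


class Metric(ABC):
    """Local stand-in mirroring the interface documented on this page."""

    needs_raw_data = False

    def __init__(self) -> None:
        self.reset()

    @abstractmethod
    def update(self, predicted, target, tags): ...

    @abstractmethod
    def reset(self): ...

    @abstractmethod
    def calculate(self): ...


class ToyTrueClassProbability(Metric):
    """Hypothetical metric: mean softmax probability of the true class."""

    needs_raw_data = True  # request raw model output instead of class labels

    def reset(self):
        self.prob_sum = 0.0
        self.count = 0

    def update(self, predicted, target, tags=None):
        # Assumption: predicted is (N, C) raw scores, target is (N,) class indices
        probs = torch.softmax(predicted.float(), dim=1)
        true_probs = probs.gather(1, target.long().unsqueeze(1)).squeeze(1)
        self.prob_sum += float(true_probs.sum())
        self.count += target.numel()

    def calculate(self):
        return self.prob_sum / max(self.count, 1)


metric = ToyTrueClassProbability()
scores = torch.tensor([[0.0, 0.0], [0.0, 0.0]])  # uniform scores over 2 classes
metric.update(scores, torch.tensor([0, 1]), tags=None)
value = metric.calculate()
```

A metric like this cannot be computed from hard class predictions, which is the kind of case the flag exists for.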
@abstractmethod
def update(self, predicted: torch.Tensor, target: torch.Tensor, tags: torch.Tensor) -> None:
Update the intrinsic parameters (with a batch)
Parameters
----------
predicted : torch.Tensor
    the main prediction tensor generated by the model
target : torch.Tensor
    the corresponding main target tensor
tags : torch.Tensor
    the tensor of meta tags (or `None`, if tags are not given)
@abstractmethod
def reset(self) -> None:
Reset the intrinsic parameters (at the beginning of an epoch)
@abstractmethod
def calculate(self) -> Union[float, Dict]:
Calculate the metric (at the end of an epoch)
Returns
-------
result : float | dict
    either the single value of the metric or a dictionary where the keys are class indices and the values are class metric values
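To illustrate the dictionary form of the return value, here is a small sketch of a per-class recall computation. It is written as a plain class with the same three methods so that it runs without dlc2action installed; in the library it would presumably subclass `Metric`. The name `ToyPerClassRecall` and the `num_classes` argument are hypothetical.

```python
from typing import Dict

import torch


class ToyPerClassRecall:
    """Hypothetical per-class recall with the reset/update/calculate interface."""

    def __init__(self, num_classes: int) -> None:
        self.num_classes = num_classes
        self.reset()

    def reset(self) -> None:
        # One counter per class index
        self.true_positive = [0] * self.num_classes
        self.support = [0] * self.num_classes

    def update(self, predicted: torch.Tensor, target: torch.Tensor, tags=None) -> None:
        for c in range(self.num_classes):
            is_c = target == c
            self.support[c] += int(is_c.sum())
            self.true_positive[c] += int((is_c & (predicted == c)).sum())

    def calculate(self) -> Dict[int, float]:
        # Keys are class indices, values are the metric value for that class
        return {
            c: (self.true_positive[c] / self.support[c]) if self.support[c] else 0.0
            for c in range(self.num_classes)
        }


metric = ToyPerClassRecall(num_classes=3)
metric.update(torch.tensor([0, 1, 1, 2]), torch.tensor([0, 1, 2, 2]))
result = metric.calculate()
```

Classes with no target samples are reported as `0.0` here; that is a choice made for this sketch, not documented library behavior.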