dlc2action.metric.base_metric

Abstract parent class for all metrics

 1#
 2# Copyright 2020-2022 by A. Mathis Group and contributors. All rights reserved.
 3#
 4# This project and all its files are licensed under GNU AGPLv3 or later version. A copy is included in dlc2action/LICENSE.AGPL.
 5#
 6"""
 7Abstract parent class for all metrics
 8"""
 9
10import torch
11from typing import Union, Dict, Any
12from abc import ABC, abstractmethod
13
14
class Metric(ABC):
    """
    Base class for all metrics

    Metrics are reset at the beginning of each epoch, updated with batch data and then calculated at the end of
    the epoch. If `needs_raw_data` is `True` for a metric class, it should expect to receive raw model output as
    the predicted vector; otherwise it should be the final class prediction.
    """

    needs_raw_data = False
    """
    If `True`, `dlc2action.task.universal_task.Task` will pass raw data to the metric (only primary predict
    function applied).
    Otherwise it will pass a prediction for the classes.
    """

    def __init__(self) -> None:
        """
        Initialize the metric by resetting its intrinsic parameters
        """
        # Subclasses define their state in `reset`, so a fresh instance starts out in the "new epoch" state.
        self.reset()

    @abstractmethod
    def update(
        self,
        predicted: torch.Tensor,
        target: torch.Tensor,
        tags: Union[torch.Tensor, None],
    ) -> None:
        """
        Update the intrinsic parameters (with a batch)

        Parameters
        ----------
        predicted : torch.Tensor
            the main prediction tensor generated by the model
        target : torch.Tensor
            the corresponding main target tensor
        tags : torch.Tensor | None
            the tensor of meta tags (or `None`, if tags are not given)
        """

    @abstractmethod
    def reset(self) -> None:
        """
        Reset the intrinsic parameters (at the beginning of an epoch)
        """

    @abstractmethod
    def calculate(self) -> Union[float, Dict]:
        """
        Calculate the metric (at the end of an epoch)

        Returns
        -------
        result : float | dict
            either the single value of the metric or a dictionary where the keys are class indices and the values
            are class metric values
        """
class Metric(abc.ABC):
class Metric(ABC):
    """
    Base class for all metrics

    Metrics are reset at the beginning of each epoch, updated with batch data and then calculated at the end of
    the epoch. If `needs_raw_data` is `True` for a metric class, it should expect to receive raw model output as
    the predicted vector; otherwise it should be the final class prediction.
    """

    needs_raw_data = False
    """
    If `True`, `dlc2action.task.universal_task.Task` will pass raw data to the metric (only primary predict
    function applied).
    Otherwise it will pass a prediction for the classes.
    """

    def __init__(self) -> None:
        """
        Initialize the metric by resetting its intrinsic parameters
        """
        # Subclasses define their state in `reset`, so a fresh instance starts out in the "new epoch" state.
        self.reset()

    @abstractmethod
    def update(
        self,
        predicted: torch.Tensor,
        target: torch.Tensor,
        tags: Union[torch.Tensor, None],
    ) -> None:
        """
        Update the intrinsic parameters (with a batch)

        Parameters
        ----------
        predicted : torch.Tensor
            the main prediction tensor generated by the model
        target : torch.Tensor
            the corresponding main target tensor
        tags : torch.Tensor | None
            the tensor of meta tags (or `None`, if tags are not given)
        """

    @abstractmethod
    def reset(self) -> None:
        """
        Reset the intrinsic parameters (at the beginning of an epoch)
        """

    @abstractmethod
    def calculate(self) -> Union[float, Dict]:
        """
        Calculate the metric (at the end of an epoch)

        Returns
        -------
        result : float | dict
            either the single value of the metric or a dictionary where the keys are class indices and the values
            are class metric values
        """

Base class for all metrics

Metrics are reset at the beginning of each epoch, updated with batch data and then calculated at the end of the epoch. If needs_raw_data is True for a metric class, it should expect to receive raw model output as the predicted vector; otherwise it should be the final class prediction

needs_raw_data = False

If True, dlc2action.task.universal_task.Task will pass raw data to the metric (only primary predict function applied). Otherwise it will pass a prediction for the classes.

@abstractmethod
def update( self, predicted: torch.Tensor, target: torch.Tensor, tags: torch.Tensor) -> None:
35    @abstractmethod
36    def update(
37        self,
38        predicted: torch.Tensor,
39        target: torch.Tensor,
40        tags: torch.Tensor,
41    ) -> None:
42        """
43        Update the intrinsic parameters (with a batch)
44
45        Parameters
46        ----------
47        predicted : torch.Tensor
48            the main prediction tensor generated by the model
49        ssl_predicted : torch.Tensor
50            the SSL prediction tensor generated by the model
51        target : torch.Tensor
52            the corresponding main target tensor
53        ssl_target : torch.Tensor
54            the corresponding SSL target tensor
55        tags : torch.Tensor
56            the tensor of meta tags (or `None`, if tags are not given)
57        """

Update the intrinsic parameters (with a batch)

Parameters

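- predicted : torch.Tensor — the main prediction tensor generated by the model
- ssl_predicted : torch.Tensor — the SSL prediction tensor generated by the model (listed in the docstring, but not a parameter of the `update(self, predicted, target, tags)` signature)
- target : torch.Tensor — the corresponding main target tensor
- ssl_target : torch.Tensor — the corresponding SSL target tensor (listed in the docstring, but not a parameter of the `update(self, predicted, target, tags)` signature)
- tags : torch.Tensor — the tensor of meta tags (or None, if tags are not given)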

@abstractmethod
def reset(self) -> None:
59    @abstractmethod
60    def reset(self) -> None:
61        """
62        Reset the intrinsic parameters (at the beginning of an epoch)
63        """

Reset the intrinsic parameters (at the beginning of an epoch)

@abstractmethod
def calculate(self) -> Union[float, Dict]:
65    @abstractmethod
66    def calculate(self) -> Union[float, Dict]:
67        """
68        Calculate the metric (at the end of an epoch)
69
70        Returns
71        -------
72        result : float | dict
73            either the single value of the metric or a dictionary where the keys are class indices and the values
74            are class metric values
75        """

Calculate the metric (at the end of an epoch)

Returns

result : float | dict — either the single value of the metric, or a dictionary where the keys are class indices and the values are the corresponding class metric values