dlc2action.metric.base_metric

Abstract parent class for all metrics.

 1#
 2# Copyright 2020-present by A. Mathis Group and contributors. All rights reserved.
 3#
 4# This project and all its files are licensed under GNU AGPLv3 or later version. 
 5# A copy is included in dlc2action/LICENSE.AGPL.
 6#
 7"""Abstract parent class for all metrics."""
 8
 9from abc import ABC, abstractmethod
10from typing import Any, Dict, Union
11
12import torch
13
14torch.manual_seed(0)
15
16
17class Metric(ABC):
18    """Base class for all metrics.
19
20    Metrics are reset at the beginning of each epoch, updated with batch data and then calculated at the end of the epoch.
21    If needs_raw_data is True for a metric class, it should expect to receive raw model output as the predicted vector;
22    otherwise it should be the final class prediction
23    """
24
25    needs_raw_data = False
26    """
27    If `True`, `dlc2action.task.universal_task.Task` will pass raw data to the metric (only primary predict
28    function applied).
29    Otherwise it will pass a prediction for the classes.
30    """
31
32    def __init__(self) -> None:
33        """Initialize the class."""
34        self.reset()
35
36    @abstractmethod
37    def update(
38        self,
39        predicted: torch.Tensor,
40        target: torch.Tensor,
41        tags: torch.Tensor,
42    ) -> None:
43        """Update the intrinsic parameters (with a batch).
44
45        Parameters
46        ----------
47        predicted : torch.Tensor
48            the main prediction tensor generated by the model
49        target : torch.Tensor
50            the corresponding main target tensor
51        tags : torch.Tensor
52            the tensor of meta tags (or `None`, if tags are not given)
53
54        """
55
56    @abstractmethod
57    def reset(self) -> None:
58        """Reset the intrinsic parameters (at the beginning of an epoch)."""
59
60    @abstractmethod
61    def calculate(self) -> Union[float, Dict]:
62        """Calculate the metric (at the end of an epoch).
63
64        Returns
65        -------
66        result : float | dict
67            either the single value of the metric or a dictionary where the keys are class indices and the values
68            are class metric values
69
70        """
class Metric(abc.ABC):
class Metric(ABC):
    """Base class for all metrics.

    Metrics are reset at the beginning of each epoch, updated with batch data and then
    calculated at the end of the epoch.
    If `needs_raw_data` is `True` for a metric class, it should expect to receive raw
    model output as the predicted vector; otherwise it receives the final class
    prediction.
    """

    needs_raw_data = False
    """
    If `True`, `dlc2action.task.universal_task.Task` will pass raw data to the metric (only the
    primary predict function applied).
    Otherwise it will pass a prediction for the classes.
    """

    def __init__(self) -> None:
        """Initialize the metric and bring it to a clean state by calling `reset`."""
        self.reset()

    @abstractmethod
    def update(
        self,
        predicted: torch.Tensor,
        target: torch.Tensor,
        # The docstring allows `None` here, so the annotation must allow it too.
        tags: Union[torch.Tensor, None],
    ) -> None:
        """Update the intrinsic parameters (with a batch).

        Parameters
        ----------
        predicted : torch.Tensor
            the main prediction tensor generated by the model
        target : torch.Tensor
            the corresponding main target tensor
        tags : torch.Tensor or None
            the tensor of meta tags, or `None` if tags are not given

        """

    @abstractmethod
    def reset(self) -> None:
        """Reset the intrinsic parameters (at the beginning of an epoch).

        Also called once from `__init__`, so subclasses must create their state here.
        """

    @abstractmethod
    def calculate(self) -> Union[float, Dict]:
        """Calculate the metric (at the end of an epoch).

        Returns
        -------
        result : float | dict
            either the single value of the metric or a dictionary where the keys are
            class indices and the values are class metric values

        """

Base class for all metrics.

Metrics are reset at the beginning of each epoch, updated with batch data, and then calculated at the end of the epoch. If needs_raw_data is True for a metric class, it should expect to receive raw model output as the predicted vector; otherwise, it receives the final class prediction.

Metric()
33    def __init__(self) -> None:
34        """Initialize the class."""
35        self.reset()

Initialize the class.

needs_raw_data = False

If True, dlc2action.task.universal_task.Task will pass raw data to the metric (with only the primary predict function applied). Otherwise, it will pass a prediction for the classes.

@abstractmethod
def update(self, predicted: torch.Tensor, target: torch.Tensor, tags: torch.Tensor) -> None:
37    @abstractmethod
38    def update(
39        self,
40        predicted: torch.Tensor,
41        target: torch.Tensor,
42        tags: torch.Tensor,
43    ) -> None:
44        """Update the intrinsic parameters (with a batch).
45
46        Parameters
47        ----------
48        predicted : torch.Tensor
49            the main prediction tensor generated by the model
50        target : torch.Tensor
51            the corresponding main target tensor
52        tags : torch.Tensor
53            the tensor of meta tags (or `None`, if tags are not given)
54
55        """

Update the intrinsic parameters (with a batch).

Parameters

predicted : torch.Tensor
    the main prediction tensor generated by the model
target : torch.Tensor
    the corresponding main target tensor
tags : torch.Tensor
    the tensor of meta tags (or None, if tags are not given)

@abstractmethod
def reset(self) -> None:
57    @abstractmethod
58    def reset(self) -> None:
59        """Reset the intrinsic parameters (at the beginning of an epoch)."""

Reset the intrinsic parameters (at the beginning of an epoch).

@abstractmethod
def calculate(self) -> Union[float, Dict]:
61    @abstractmethod
62    def calculate(self) -> Union[float, Dict]:
63        """Calculate the metric (at the end of an epoch).
64
65        Returns
66        -------
67        result : float | dict
68            either the single value of the metric or a dictionary where the keys are class indices and the values
69            are class metric values
70
71        """

Calculate the metric (at the end of an epoch).

Returns

result : float | dict
    either the single value of the metric or a dictionary where the keys are class indices and the values are class metric values