dlc2action.metric.metrics

   1#
   2# Copyright 2020-present by A. Mathis Group and contributors. All rights reserved.
   3#
   4# This project and all its files are licensed under GNU AGPLv3 or later version. 
   5# A copy is included in dlc2action/LICENSE.AGPL.
   6#
   7"""Implementations of `dlc2action.metric.base_metric.Metric`."""
   8
   9import warnings
  10from abc import abstractmethod
  11from collections import defaultdict
  12from copy import copy, deepcopy
  13from typing import Any, Dict, List, Set, Tuple, Union
  14
  15import editdistance
  16import numpy as np
  17import torch
  18from dlc2action.metric.base_metric import Metric
  19from sklearn import metrics
  20
  21
  22class _ClassificationMetric(Metric):
  23    """The base class for all metric that are calculated from true and false negative and positive rates."""
  24
  25    needs_raw_data = True
  26    segmental = False
  27    """
  28    If `True`, the metric will be calculated over segments; otherwise over frames.
  29    """
  30
  31    def __init__(
  32        self,
  33        num_classes: int,
  34        ignore_index: int = -100,
  35        average: str = "macro",
  36        ignored_classes: Set = None,
  37        exclusive: bool = True,
  38        iou_threshold: float = 0.5,
  39        tag_average: str = "micro",
  40        threshold_values: List = None,
  41        integration_interval: int = 0,
  42    ):
  43        """Initialize the metric.
  44
  45        Parameters
  46        ----------
  47        ignore_index : int, default -100
  48            the class index that indicates ignored samples
  49        average: {'macro', 'micro', 'none'}
  50            method for averaging across classes
  51        num_classes : int, optional
  52            number of classes (not necessary if main_class is not None)
  53        ignored_classes : set, optional
  54            a set of class ids to ignore in calculation
  55        exclusive: bool, default True
  56            set to False for multi-label classification tasks
  57        iou_threshold : float, default 0.5
  58            if segmental is true, intervals with IoU larger than this threshold are considered correct
  59        tag_average: {'micro', 'macro', 'none'}
  60            method for averaging across meta tags (if given)
  61        threshold_values : List, optional
  62            a list of float values between 0 and 1 that will be used as decision thresholds
  63        integration_interval : int, default 0
  64            if not 0, the metric will be calculated over a window of this size (in frames) around the current frame
  65
  66        """
  67        super().__init__()
  68        self.ignore_index = ignore_index
  69        self.average = average
  70        self.tag_average = tag_average
  71        self.tags = set()
  72        self.exclusive = exclusive
  73        self.threshold = iou_threshold
  74        self.integration_interval = integration_interval
  75        self.classes = list(range(int(num_classes)))
  76        self.optimize = False
  77        if threshold_values is None:
  78            threshold_values = [0.5]
  79        elif threshold_values == ["optimize"]:
  80            threshold_values = list(np.arange(0.25, 0.75, 0.05))
  81            self.optimize = True
  82        if self.exclusive and threshold_values != [0.5]:
  83            raise ValueError(
  84                "Cannot set threshold values for exclusive classification!"
  85            )
  86        self.threshold_values = threshold_values
  87
  88        if ignored_classes is not None:
  89            for c in ignored_classes:
  90                if c in self.classes:
  91                    self.classes.remove(c)
  92        if len(self.classes) == 0:
  93            warnings.warn("No classes are followed!")
  94
  95    def reset(self) -> None:
  96        """Reset the internal state (at the beginning of an epoch)."""
  97        self.tags = set()
  98        self.tp = defaultdict(lambda: defaultdict(lambda: 0))
  99        self.fp = defaultdict(lambda: defaultdict(lambda: 0))
 100        self.fn = defaultdict(lambda: defaultdict(lambda: 0))
 101        self.tn = defaultdict(lambda: defaultdict(lambda: 0))
 102
 103    def update(
 104        self,
 105        predicted: torch.Tensor,
 106        target: torch.Tensor,
 107        tags: torch.Tensor,
 108    ) -> None:
 109        """Update the internal state (with a batch).
 110
 111        Parameters
 112        ----------
 113        predicted : torch.Tensor
 114            the main prediction tensor generated by the model
 115        ssl_predicted : torch.Tensor
 116            the SSL prediction tensor generated by the model
 117        target : torch.Tensor
 118            the corresponding main target tensor
 119        ssl_target : torch.Tensor
 120            the corresponding SSL target tensor
 121        tags : torch.Tensor
 122            the tensor of meta tags (or `None`, if tags are not given)
 123
 124        """
 125        if self.segmental:
 126            self._update_segmental(predicted, target, tags)
 127        else:
 128            self._update_normal(predicted, target, tags)
 129
 130    def _key(self, tag: torch.Tensor, c: int) -> str:
 131        """Get a key for the intermediate value dictionaries from tag and class indices."""
 132        if tag is None or self.tag_average == "micro":
 133            return c
 134        else:
 135            return f"tag{int(tag)}_{c}"
 136
 137    def _update_normal(
 138        self,
 139        predicted: torch.Tensor,
 140        target: torch.Tensor,
 141        tags: torch.Tensor,
 142    ) -> None:
 143        """Update the internal state (with a batch), calculating over frames."""
 144        if self.integration_interval != 0:
 145            pr = 0
 146            denom = torch.ones((1, 1, predicted.shape[-1])) * (
 147                2 * self.integration_interval + 1
 148            )
 149            for i in range(self.integration_interval):
 150                pr += torch.cat(
 151                    [
 152                        torch.zeros((predicted.shape[0], predicted.shape[1], i + 1)),
 153                        predicted[:, :, i + 1 :],
 154                    ],
 155                    dim=-1,
 156                )
 157                pr += torch.cat(
 158                    [
 159                        predicted[:, :, : -i - 1],
 160                        torch.zeros((predicted.shape[0], predicted.shape[1], i + 1)),
 161                    ],
 162                    dim=-1,
 163                )
 164                denom[:, :, i] = i + 1
 165                denom[:, :, -i] = i + 1
 166            predicted = pr / denom
 167        if self.exclusive:
 168            predicted = torch.max(predicted, 1)[1]
 169        if self.tag_average == "micro" or tags is None or tags[0] is None:
 170            tag_set = {None}
 171        else:
 172            tag_set = set(tags)
 173            self.tags.update(tag_set)
 174        for thr in self.threshold_values:
 175            for t in tag_set:
 176                if t is None:
 177                    predicted_t = predicted
 178                    target_t = target
 179                else:
 180                    predicted_t = predicted[tags == t]
 181                    target_t = target[tags == t]
 182                for c in self.classes:
 183                    if self.exclusive:
 184                        pos = (predicted_t == c) * (target_t != self.ignore_index)
 185                        neg = (predicted_t != c) * (target_t != self.ignore_index)
 186                        self.tp[self._key(t, c)][thr] += ((target_t == c) * pos).sum()
 187                        self.fp[self._key(t, c)][thr] += ((target_t != c) * pos).sum()
 188                        self.fn[self._key(t, c)][thr] += ((target_t == c) * neg).sum()
 189                        self.tn[self._key(t, c)][thr] += ((target_t != c) * neg).sum()
 190                    else:
 191                        if isinstance(thr, list):
 192                            key = ", ".join(map(str, thr))
 193                            for i, tt in enumerate(thr):
 194                                predicted_t[:, i, :] = (predicted_t[:, i, :] > tt).int()
 195                        else:
 196                            key = thr
 197                            predicted_t = (predicted_t > thr).int()
 198                        pos = predicted_t[:, c, :] * (
 199                            target_t[:, c, :] != self.ignore_index
 200                        )
 201                        neg = (predicted_t[:, c, :] != 1) * (
 202                            target_t[:, c, :] != self.ignore_index
 203                        )
 204                        self.tp[self._key(t, c)][key] += (target_t[:, c, :] * pos).sum()
 205                        self.fp[self._key(t, c)][key] += (
 206                            (target_t[:, c, :] != 1) * pos
 207                        ).sum()
 208                        self.fn[self._key(t, c)][key] += (target_t[:, c, :] * neg).sum()
 209                        self.tn[self._key(t, c)][key] += (
 210                            (target_t[:, c, :] != 1) * neg
 211                        ).sum()
 212
 213    def _get_intervals(self, tensor: torch.Tensor) -> torch.Tensor:
 214        """Get a list of `True` group beginning and end indices from a boolean tensor."""
 215        output, indices = torch.unique_consecutive(tensor, return_inverse=True)
 216        true_indices = torch.where(output)[0]
 217        starts = torch.tensor(
 218            [(indices == i).nonzero(as_tuple=True)[0][0] for i in true_indices]
 219        )
 220        ends = torch.tensor(
 221            [(indices == i).nonzero(as_tuple=True)[0][-1] + 1 for i in true_indices]
 222        )
 223        return torch.stack([starts, ends]).T
 224
 225    def _smooth(self, tensor: torch.Tensor, smooth_interval: int = 1) -> torch.Tensor:
 226        """Get rid of jittering in a non-exclusive classification tensor.
 227
 228        First, remove intervals of 0 shorter than `smooth_interval`. Then, remove intervals of 1 shorter than
 229        `smooth_interval`.
 230
 231        """
 232        for c in self.classes:
 233            intervals = self._get_intervals(tensor[:, c] == 0)
 234            interval_lengths = torch.tensor(
 235                [interval[1] - interval[0] for interval in intervals]
 236            )
 237            short_intervals = intervals[interval_lengths <= smooth_interval]
 238            for start, end in short_intervals:
 239                tensor[start:end, c] = 1
 240            intervals = self._get_intervals(tensor[:, c] == 1)
 241            interval_lengths = torch.tensor(
 242                [interval[1] - interval[0] for interval in intervals]
 243            )
 244            short_intervals = intervals[interval_lengths <= smooth_interval]
 245            for start, end in short_intervals:
 246                tensor[start:end, c] = 0
 247        return tensor
 248
 249    def _sigmoid_threshold_function(
 250        self,
 251        low_threshold: float,
 252        high_threshold: float,
 253        low_length: int,
 254        high_length: int,
 255    ):
 256        """Generate a sigmoid threshold function.
 257
 258        The resulting function outputs an intersection threshold given the length of the interval.
 259
 260        """
 261        a = 2 / (high_length - low_length)
 262        b = 1 - a * high_length
 263        return lambda x: low_threshold + torch.sigmoid(4 * (a * x + b)) * (
 264            high_threshold - low_threshold
 265        )
 266
 267    def _update_segmental(
 268        self,
 269        predicted: torch.tensor,
 270        target: torch.tensor,
 271        tags: torch.Tensor,
 272    ) -> None:
 273        """Update the internal state (with a batch), calculating over segments."""
 274        if self.exclusive:
 275            predicted = torch.max(predicted, 1)[1]
 276        predicted = torch.cat(
 277            [
 278                copy(predicted),
 279                -100 * torch.ones((*predicted.shape[:-1], 1)).to(predicted.device),
 280            ],
 281            dim=-1,
 282        )
 283        target = torch.cat(
 284            [
 285                copy(target),
 286                -100 * torch.ones((*target.shape[:-1], 1)).to(target.device),
 287            ],
 288            dim=-1,
 289        )
 290        if self.exclusive:
 291            predicted = predicted.flatten()
 292            target = target.flatten()
 293        else:
 294            num_classes = predicted.shape[1]
 295            predicted = predicted.transpose(1, 2).reshape(-1, num_classes)
 296            target = target.transpose(1, 2).reshape(-1, num_classes)
 297        if self.tag_average == "micro" or tags is None or tags[0] is None:
 298            tag_set = {None}
 299        else:
 300            tag_set = set(tags)
 301            self.tags.update(tag_set)
 302        for thr in self.threshold_values:
 303            key = thr
 304            for t in tag_set:
 305                if t is None:
 306                    predicted_t = predicted
 307                    target_t = target
 308                else:
 309                    predicted_t = predicted[tags == t]
 310                    target_t = target[tags == t]
 311                if not self.exclusive:
 312                    if isinstance(thr, list):
 313                        for i, tt in enumerate(thr):
 314                            predicted_t[i, :] = (predicted_t[i, :] > tt).int()
 315                        key = ", ".join(map(str, thr))
 316                    else:
 317                        predicted_t = (predicted_t > thr).int()
 318                for c in self.classes:
 319                    if self.exclusive:
 320                        predicted_intervals = self._get_intervals(predicted_t == c)
 321                        target_intervals = self._get_intervals(target_t == c)
 322                    else:
 323                        predicted_intervals = self._get_intervals(
 324                            predicted_t[:, c] == 1
 325                        )
 326                        target_intervals = self._get_intervals(target_t[:, c] == 1)
 327                    true_used = torch.zeros(target_intervals.shape[0])
 328                    for interval in predicted_intervals:
 329                        if len(target_intervals) > 0:
 330                            # Compute IoU against all others
 331                            intersection = torch.minimum(
 332                                interval[1], target_intervals[:, 1]
 333                            ) - torch.maximum(interval[0], target_intervals[:, 0])
 334                            union = torch.maximum(
 335                                interval[1], target_intervals[:, 1]
 336                            ) - torch.minimum(interval[0], target_intervals[:, 0])
 337                            IoU = intersection / union
 338
 339                            # Get the best scoring segment
 340                            idx = IoU.argmax()
 341
 342                            # If the IoU is high enough and the true segment isn't already used
 343                            # Then it is a true positive. Otherwise is it a false positive.
 344                            if IoU[idx] >= self.threshold and not true_used[idx]:
 345                                self.tp[self._key(t, c)][key] += 1
 346                                true_used[idx] = 1
 347                            else:
 348                                self.fp[self._key(t, c)][key] += 1
 349                        else:
 350                            self.fp[self._key(t, c)][key] += 1
 351                    self.fn[self._key(t, c)][key] += len(true_used) - torch.sum(
 352                        true_used
 353                    )
 354
 355    def calculate(self) -> Union[float, Dict]:
 356        """Calculate the metric (at the end of an epoch).
 357
 358        Returns
 359        -------
 360        result : float | dict
 361            either the single value of the metric or a dictionary where the keys are class indices and the values
 362            are class metric values
 363
 364        """
 365        if self.tag_average == "micro" or self.tags == {None}:
 366            self.tags = {None}
 367        result = {}
 368        self.tags = sorted(list(self.tags))
 369        for tag in self.tags:
 370            if self.average == "macro":
 371                metric_list = []
 372                for c in self.classes:
 373                    metric_list.append(
 374                        self._calculate_metric(
 375                            self.tp[self._key(tag, c)],
 376                            self.fp[self._key(tag, c)],
 377                            self.fn[self._key(tag, c)],
 378                            self.tn[self._key(tag, c)],
 379                        )
 380                    )
 381                if tag is not None:
 382                    tag = int(tag)
 383                result[f"tag{tag}"] = sum(metric_list) / len(metric_list)
 384            elif self.average == "micro":
 385                # tp = sum([self.tp[self._key(tag, c)] for c in self.classes])
 386                # fn = sum([self.fn[self._key(tag, c)] for c in self.classes])
 387                # fp = sum([self.fp[self._key(tag, c)] for c in self.classes])
 388                # tn = sum([self.tn[self._key(tag, c)] for c in self.classes])
 389                # if tag is not None:
 390                #     tag = int(tag)
 391                # result[f"tag{tag}"] = self._calculate_metric(tp, fp, fn, tn)
 392                metric_list = []
 393                for c in self.classes:
 394                    metric_list.append(
 395                        self._calculate_metric(
 396                            self.tp[self._key(tag, c)],
 397                            self.fp[self._key(tag, c)],
 398                            self.fn[self._key(tag, c)],
 399                            self.tn[self._key(tag, c)],
 400                        )
 401                    )
 402
 403                if tag is not None:
 404                    tag = int(tag)
 405                result[f"tag{tag}"] = metric_list
 406            elif self.average == "none":
 407                metric_dict = {}
 408                for c in self.classes:
 409                    metric_dict[self._key(tag, c)] = self._calculate_metric(
 410                        self.tp[self._key(tag, c)],
 411                        self.fp[self._key(tag, c)],
 412                        self.fn[self._key(tag, c)],
 413                        self.tn[self._key(tag, c)],
 414                    )
 415                result.update(metric_dict)
 416            else:
 417                raise ValueError(
 418                    f"The {self.average} averaging method is not available, please choose from "
 419                    f'["none", "micro", "macro"]'
 420                )
 421        if len(self.tags) == 1 and self.average != "none":
 422            tag = self.tags[0]
 423            if tag is not None:
 424                tag = int(tag)
 425            result = result[f"tag{tag}"]
 426        elif self.tag_average == "macro":
 427            if self.average == "none":
 428                r = {}
 429                for c in self.classes:
 430                    r[c] = torch.mean(
 431                        torch.tensor([result[self._key(tag, c)] for tag in self.tags])
 432                    )
 433            else:
 434                r = torch.mean(
 435                    torch.tensor([result[f"tag{int(tag)}"] for tag in self.tags])
 436                )
 437            result = r
 438        return result
 439
 440    @abstractmethod
 441    def _calculate_metric(self, tp: Dict, fp: Dict, fn: Dict, tn: Dict) -> float:
 442        """Calculate the metric value from true and false positive and negative rates.
 443
 444        Parameters
 445        ----------
 446        tp : float
 447            true positive rate
 448        fp: float
 449            false positive rate
 450        fn: float
 451            false negative rate
 452        tn: float
 453            true negative rate
 454
 455        Returns
 456        -------
 457        metric : float
 458            metric value
 459
 460        """
 461
 462
 463class PR_AUC(_ClassificationMetric):
 464    """Area under precision-recall curve (not advised for training)."""
 465
 466    def __init__(
 467        self,
 468        num_classes: int,
 469        ignore_index: int = -100,
 470        average: str = "macro",
 471        ignored_classes: Set = None,
 472        exclusive: bool = False,
 473        tag_average: str = "micro",
 474        threshold_step: float = 0.1,
 475    ):
 476        """Initialize the metric.
 477
 478        Parameters
 479        ----------
 480        ignore_index : int, default -100
 481            the class index that indicates ignored samples
 482        average: {'macro', 'micro', 'none'}
 483            method for averaging across classes
 484        num_classes : int, optional
 485            number of classes (not necessary if main_class is not None)
 486        ignored_classes : set, optional
 487            a set of class ids to ignore in calculation
 488        exclusive: bool, default True
 489            set to False for multi-label classification tasks
 490        tag_average: {'micro', 'macro', 'none'}
 491            method for averaging across meta tags (if given)
 492        threshold_step : float, default 0.1
 493            the decision threshold step
 494
 495        """
 496        if exclusive:
 497            raise ValueError(
 498                "The PR-AUC metric is not implemented for exclusive classification!"
 499            )
 500        super().__init__(
 501            num_classes,
 502            ignore_index,
 503            average,
 504            ignored_classes,
 505            exclusive,
 506            tag_average=tag_average,
 507            threshold_values=list(np.arange(0, 1, threshold_step)),
 508        )
 509
 510    def _calculate_metric(self, tp: Dict, fp: Dict, fn: Dict, tn: Dict) -> float:
 511        """Calculate the metric value from true and false positive and negative rates."""
 512        precisions = []
 513        recalls = []
 514        for k in sorted(self.threshold_values):
 515            precisions.append(tp[k] / (tp[k] + fp[k] + 1e-7))
 516            recalls.append(tp[k] / (tp[k] + fn[k] + 1e-7))
 517        return metrics.auc(x=recalls, y=precisions)
 518
 519
 520class Precision(_ClassificationMetric):
 521    """Precision."""
 522
 523    def __init__(
 524        self,
 525        num_classes: int,
 526        ignore_index: int = -100,
 527        average: str = "macro",
 528        ignored_classes: Set = None,
 529        exclusive: bool = True,
 530        tag_average: str = "micro",
 531        threshold_value: Union[float, List] = None,
 532    ):
 533        """Initialize the class.
 534
 535        Parameters
 536        ----------
 537        ignore_index : int, default -100
 538            the class index that indicates ignored samples
 539        average: {'macro', 'micro', 'none'}
 540            method for averaging across classes
 541        num_classes : int, optional
 542            number of classes (not necessary if main_class is not None)
 543        ignored_classes : set, optional
 544            a set of class ids to ignore in calculation
 545        exclusive: bool, default True
 546            set to False for multi-label classification tasks
 547        tag_average: {'micro', 'macro', 'none'}
 548            method for averaging across meta tags (if given)
 549        threshold_value : float | list, optional
 550            the decision threshold value (cannot be defined for exclusive classification, and 0.5 be default
 551            for non-exclusive); if `threshold_value` is a list, every value should correspond to the class
 552            under the same index
 553
 554        """
 555        if threshold_value is None:
 556            threshold_value = 0.5
 557        super().__init__(
 558            num_classes,
 559            ignore_index,
 560            average,
 561            ignored_classes,
 562            exclusive,
 563            tag_average=tag_average,
 564            threshold_values=[threshold_value],
 565        )
 566
 567    def _calculate_metric(self, tp: Dict, fp: Dict, fn: Dict, tn: Dict) -> float:
 568        """Calculate the metric value from true and false positive and negative rates."""
 569        k = self.threshold_values[0]
 570        if isinstance(k, list):
 571            k = ", ".join(map(str, k))
 572        return tp[k] / (tp[k] + fp[k] + 1e-7)
 573
 574
 575class SegmentalPrecision(_ClassificationMetric):
 576    """Segmental precision (not advised for training)."""
 577
 578    segmental = True
 579
 580    def __init__(
 581        self,
 582        num_classes: int,
 583        ignore_index: int = -100,
 584        average: str = "macro",
 585        ignored_classes: Set = None,
 586        exclusive: bool = True,
 587        iou_threshold: float = 0.5,
 588        tag_average: str = "micro",
 589        threshold_value: Union[float, List] = None,
 590    ):
 591        """Initialize the class.
 592
 593        Parameters
 594        ----------
 595        ignore_index : int, default -100
 596            the class index that indicates ignored samples
 597        average: {'macro', 'micro', 'none'}
 598            method for averaging across classes
 599        num_classes : int, optional
 600            number of classes (not necessary if main_class is not None)
 601        ignored_classes : set, optional
 602            a set of class ids to ignore in calculation
 603        exclusive: bool, default True
 604            set to False for multi-label classification tasks
 605        iou_threshold : float, default 0.5
 606            if segmental is true, intervals with IoU larger than this threshold are considered correct
 607        tag_average: {'micro', 'macro', 'none'}
 608            method for averaging across meta tags (if given)
 609        threshold_value : float | list, optional
 610            the decision threshold value (cannot be defined for exclusive classification, and 0.5 be default
 611            for non-exclusive); if `threshold_value` is a list, every value should correspond to the class
 612            under the same index
 613
 614        """
 615        if threshold_value is None:
 616            threshold_value = 0.5
 617        super().__init__(
 618            num_classes,
 619            ignore_index,
 620            average,
 621            ignored_classes,
 622            exclusive,
 623            iou_threshold,
 624            tag_average,
 625            [threshold_value],
 626        )
 627
 628    def _calculate_metric(self, tp: Dict, fp: Dict, fn: Dict, tn: Dict) -> float:
 629        """Calculate the metric value from true and false positive and negative rates."""
 630        k = self.threshold_values[0]
 631        if isinstance(k, list):
 632            k = ", ".join(map(str, k))
 633        return tp[k] / (tp[k] + fp[k] + 1e-7)
 634
 635
 636class Recall(_ClassificationMetric):
 637    """Recall."""
 638
 639    def __init__(
 640        self,
 641        num_classes: int,
 642        ignore_index: int = -100,
 643        average: str = "macro",
 644        ignored_classes: Set = None,
 645        exclusive: bool = True,
 646        tag_average: str = "micro",
 647        threshold_value: Union[float, List] = None,
 648    ):
 649        """Initialize the class.
 650
 651        Parameters
 652        ----------
 653        ignore_index : int, default -100
 654            the class index that indicates ignored samples
 655        average: {'macro', 'micro', 'none'}
 656            method for averaging across classes
 657        num_classes : int, optional
 658            number of classes (not necessary if main_class is not None)
 659        ignored_classes : set, optional
 660            a set of class ids to ignore in calculation
 661        exclusive: bool, default True
 662            set to False for multi-label classification tasks
 663        tag_average: {'micro', 'macro', 'none'}
 664            method for averaging across meta tags (if given)
 665        threshold_value : float | list, optional
 666            the decision threshold value (cannot be defined for exclusive classification, and 0.5 be default
 667            for non-exclusive); if `threshold_value` is a list, every value should correspond to the class
 668            under the same index
 669
 670        """
 671        if threshold_value is None:
 672            threshold_value = 0.5
 673        super().__init__(
 674            num_classes,
 675            ignore_index,
 676            average,
 677            ignored_classes,
 678            exclusive,
 679            tag_average=tag_average,
 680            threshold_values=[threshold_value],
 681        )
 682
 683    def _calculate_metric(self, tp: Dict, fp: Dict, fn: Dict, tn: Dict) -> float:
 684        """Calculate the metric value from true and false positive and negative rates."""
 685        k = self.threshold_values[0]
 686        if isinstance(k, list):
 687            k = ", ".join(map(str, k))
 688        return tp[k] / (tp[k] + fn[k] + 1e-7)
 689
 690
 691class SegmentalRecall(_ClassificationMetric):
 692    """Segmental recall (not advised for training)."""
 693
 694    segmental = True
 695
 696    def __init__(
 697        self,
 698        num_classes: int,
 699        ignore_index: int = -100,
 700        average: str = "macro",
 701        ignored_classes: Set = None,
 702        exclusive: bool = True,
 703        iou_threshold: float = 0.5,
 704        tag_average: str = "micro",
 705        threshold_value: Union[float, List] = None,
 706    ):
 707        """Initialize the class.
 708
 709        Parameters
 710        ----------
 711        ignore_index : int, default -100
 712            the class index that indicates ignored samples
 713        average: {'macro', 'micro', 'none'}
 714            method for averaging across classes
 715        num_classes : int, optional
 716            number of classes (not necessary if main_class is not None)
 717        ignored_classes : set, optional
 718            a set of class ids to ignore in calculation
 719        exclusive: bool, default True
 720            set to False for multi-label classification tasks
 721        iou_threshold : float, default 0.5
 722            if segmental is true, intervals with IoU larger than this threshold are considered correct
 723        tag_average: {'micro', 'macro', 'none'}
 724            method for averaging across meta tags (if given)
 725        threshold_value : float | list, optional
 726            the decision threshold value (cannot be defined for exclusive classification, and 0.5 be default
 727            for non-exclusive); if `threshold_value` is a list, every value should correspond to the class
 728            under the same index
 729
 730        """
 731        if threshold_value is None:
 732            threshold_value = 0.5
 733        super().__init__(
 734            num_classes,
 735            ignore_index,
 736            average,
 737            ignored_classes,
 738            exclusive,
 739            iou_threshold,
 740            tag_average,
 741            [threshold_value],
 742        )
 743
 744    def _calculate_metric(self, tp: Dict, fp: Dict, fn: Dict, tn: Dict) -> float:
 745        """Calculate the metric value from true and false positive and negative rates."""
 746        k = self.threshold_values[0]
 747        if isinstance(k, list):
 748            k = ", ".join(map(str, k))
 749        return tp[k] / (tp[k] + fn[k] + 1e-7)
 750
 751
 752class F1(_ClassificationMetric):
 753    """F1 score."""
 754
 755    def __init__(
 756        self,
 757        num_classes: int,
 758        ignore_index: int = -100,
 759        average: str = "macro",
 760        ignored_classes: Set = None,
 761        exclusive: bool = True,
 762        tag_average: str = "micro",
 763        threshold_value: Union[float, List] = None,
 764        integration_interval: int = 0,
 765    ):
 766        """Initialize the class.
 767
 768        Parameters
 769        ----------
 770        ignore_index : int, default -100
 771            the class index that indicates ignored samples
 772        average: {'macro', 'micro', 'none'}
 773            method for averaging across classes
 774        num_classes : int, optional
 775            number of classes (not necessary if main_class is not None)
 776        ignored_classes : set, optional
 777            a set of class ids to ignore in calculation
 778        exclusive: bool, default True
 779            set to False for multi-label classification tasks
 780        tag_average: {'micro', 'macro', 'none'}
 781            method for averaging across meta tags (if given)
 782        threshold_value : float | list, optional
 783            the decision threshold value (cannot be defined for exclusive classification, and 0.5 be default
 784            for non-exclusive); if `threshold_value` is a list, every value should correspond to the class
 785            under the same index
 786        integration_interval : int, default 0
 787            the number of frames to integrate over (0 means no integration)
 788
 789        """
 790        if threshold_value is None:
 791            threshold_value = 0.5
 792        super().__init__(
 793            num_classes,
 794            ignore_index,
 795            average,
 796            ignored_classes,
 797            exclusive,
 798            tag_average=tag_average,
 799            threshold_values=[threshold_value],
 800            integration_interval=integration_interval,
 801        )
 802
 803    def _calculate_metric(self, tp: Dict, fp: Dict, fn: Dict, tn: Dict) -> float:
 804        """Calculate the metric value from true and false positive and negative rates."""
 805        if self.optimize:
 806            scores = []
 807            for k in self.threshold_values:
 808                recall = tp[k] / (tp[k] + fn[k] + 1e-7)
 809                precision = tp[k] / (tp[k] + fp[k] + 1e-7)
 810                scores.append(2 * recall * precision / (recall + precision + 1e-7))
 811            f1 = max(scores)
 812        else:
 813            k = self.threshold_values[0]
 814            if isinstance(k, list):
 815                k = ", ".join(map(str, k))
 816            recall = tp[k] / (tp[k] + fn[k] + 1e-7)
 817            precision = tp[k] / (tp[k] + fp[k] + 1e-7)
 818            f1 = 2 * recall * precision / (recall + precision + 1e-7)
 819        return f1
 820
 821
 822class SegmentalF1(_ClassificationMetric):
 823    """Segmental F1 score (not advised for training)."""
 824
 825    segmental = True
 826
 827    def __init__(
 828        self,
 829        num_classes: int,
 830        ignore_index: int = -100,
 831        average: str = "macro",
 832        ignored_classes: Set = None,
 833        exclusive: bool = True,
 834        iou_threshold: float = 0.5,
 835        tag_average: str = "micro",
 836        threshold_value: Union[float, List] = None,
 837    ):
 838        """Initialize the class.
 839
 840        Parameters
 841        ----------
 842        ignore_index : int, default -100
 843            the class index that indicates ignored samples
 844        average: {'macro', 'micro', 'none'}
 845            method for averaging across classes
 846        num_classes : int, optional
 847            number of classes (not necessary if main_class is not None)
 848        ignored_classes : set, optional
 849            a set of class ids to ignore in calculation
 850        exclusive: bool, default True
 851            set to False for multi-label classification tasks
 852        iou_threshold : float, default 0.5
 853            if segmental is true, intervals with IoU larger than this threshold are considered correct
 854        tag_average: {'micro', 'macro', 'none'}
 855            method for averaging across meta tags (if given)
 856        threshold_value : float | list, optional
 857            the decision threshold value (cannot be defined for exclusive classification, and 0.5 be default
 858            for non-exclusive); if `threshold_value` is a list, every value should correspond to the class
 859            under the same index
 860
 861        """
 862        if threshold_value is None:
 863            threshold_value = 0.5
 864        super().__init__(
 865            num_classes,
 866            ignore_index,
 867            average,
 868            ignored_classes,
 869            exclusive,
 870            iou_threshold,
 871            tag_average,
 872            [threshold_value],
 873        )
 874
 875    def _calculate_metric(self, tp: Dict, fp: Dict, fn: Dict, tn: Dict) -> float:
 876        """Calculate the metric value from true and false positive and negative rates."""
 877        if self.optimize:
 878            scores = []
 879            for k in self.threshold_values:
 880                recall = tp[k] / (tp[k] + fn[k] + 1e-7)
 881                precision = tp[k] / (tp[k] + fp[k] + 1e-7)
 882                scores.append(2 * recall * precision / (recall + precision + 1e-7))
 883            f1 = max(scores)
 884        else:
 885            k = self.threshold_values[0]
 886            if isinstance(k, list):
 887                k = ", ".join(map(str, k))
 888            recall = tp[k] / (tp[k] + fn[k] + 1e-7)
 889            precision = tp[k] / (tp[k] + fp[k] + 1e-7)
 890            f1 = 2 * recall * precision / (recall + precision + 1e-7)
 891        return f1
 892
 893
 894class Fbeta(_ClassificationMetric):
 895    """F-beta score."""
 896
 897    def __init__(
 898        self,
 899        beta: float = 1,
 900        ignore_index: int = -100,
 901        average: str = "macro",
 902        num_classes: int = None,
 903        ignored_classes: Set = None,
 904        tag_average: str = "micro",
 905        exclusive: bool = True,
 906        threshold_value: float = 0.5,
 907    ):
 908        """Initialize the class.
 909
 910        Parameters
 911        ----------
 912        beta : float, default 1
 913            the beta parameter
 914        ignore_index : int, default -100
 915            the class index that indicates ignored samples
 916        average: {'macro', 'micro', 'none'}
 917            method for averaging across classes
 918        num_classes : int, optional
 919            number of classes (not necessary if main_class is not None)
 920        ignored_classes : set, optional
 921            a set of class ids to ignore in calculation
 922        exclusive: bool, default True
 923            set to False for multi-label classification tasks
 924        tag_average: {'micro', 'macro', 'none'}
 925            method for averaging across meta tags (if given)
 926        threshold_value : float | list, optional
 927            the decision threshold value (cannot be defined for exclusive classification, and 0.5 be default
 928            for non-exclusive); if `threshold_value` is a list, every value should correspond to the class
 929            under the same index
 930
 931        """
 932        if threshold_value is None:
 933            threshold_value = 0.5
 934        self.beta2 = beta**2
 935        super().__init__(
 936            num_classes,
 937            ignore_index,
 938            average,
 939            ignored_classes,
 940            exclusive,
 941            tag_average=tag_average,
 942            threshold_values=[threshold_value],
 943        )
 944
 945    def _calculate_metric(self, tp: Dict, fp: Dict, fn: Dict, tn: Dict) -> float:
 946        """Calculate the metric value from true and false positive and negative rates."""
 947        if self.optimize:
 948            scores = []
 949            for k in self.threshold_values:
 950                recall = tp[k] / (tp[k] + fn[k] + 1e-7)
 951                precision = tp[k] / (tp[k] + fp[k] + 1e-7)
 952                scores.append(
 953                    (
 954                        (1 + self.beta2)
 955                        * precision
 956                        * recall
 957                        / (self.beta2 * precision + recall + 1e-7)
 958                    )
 959                )
 960            f1 = max(scores)
 961        else:
 962            k = self.threshold_values[0]
 963            if isinstance(k, list):
 964                k = ", ".join(map(str, k))
 965            recall = tp[k] / (tp[k] + fn[k] + 1e-7)
 966            precision = tp[k] / (tp[k] + fp[k] + 1e-7)
 967            f1 = (
 968                (1 + self.beta2)
 969                * precision
 970                * recall
 971                / (self.beta2 * precision + recall + 1e-7)
 972            )
 973        return f1
 974
 975
 976class SegmentalFbeta(_ClassificationMetric):
 977    """Segmental F-beta score (not advised for training)."""
 978
 979    segmental = True
 980
 981    def __init__(
 982        self,
 983        beta: float = 1,
 984        ignore_index: int = -100,
 985        average: str = "macro",
 986        num_classes: int = None,
 987        ignored_classes: Set = None,
 988        iou_threshold: float = 0.5,
 989        tag_average: str = "micro",
 990        exclusive: bool = True,
 991        threshold_value: float = 0.5,
 992    ):
 993        """Initialize the class.
 994
 995        Parameters
 996        ----------
 997        beta : float, default 1
 998            the beta parameter
 999        ignore_index : int, default -100
1000            the class index that indicates ignored samples
1001        average: {'macro', 'micro', 'none'}
1002            method for averaging across classes
1003        num_classes : int, optional
1004            number of classes (not necessary if main_class is not None)
1005        ignored_classes : set, optional
1006            a set of class ids to ignore in calculation
1007        exclusive: bool, default True
1008            set to False for multi-label classification tasks
1009        iou_threshold : float, default 0.5
1010            if segmental is true, intervals with IoU larger than this threshold are considered correct
1011        tag_average: {'micro', 'macro', 'none'}
1012            method for averaging across meta tags (if given)
1013        threshold_value : float | list, optional
1014            the decision threshold value (cannot be defined for exclusive classification, and 0.5 be default
1015            for non-exclusive); if `threshold_value` is a list, every value should correspond to the class
1016            under the same index
1017
1018        """
1019        if threshold_value is None:
1020            threshold_value = 0.5
1021        self.beta2 = beta**2
1022        super().__init__(
1023            num_classes,
1024            ignore_index,
1025            average,
1026            ignored_classes,
1027            exclusive,
1028            iou_threshold=iou_threshold,
1029            tag_average=tag_average,
1030            threshold_values=[threshold_value],
1031        )
1032
1033    def _calculate_metric(self, tp: Dict, fp: Dict, fn: Dict, tn: Dict) -> float:
1034        """Calculate the metric value from true and false positive and negative rates."""
1035        if self.optimize:
1036            scores = []
1037            for k in self.threshold_values:
1038                recall = tp[k] / (tp[k] + fn[k] + 1e-7)
1039                precision = tp[k] / (tp[k] + fp[k] + 1e-7)
1040                scores.append(
1041                    (
1042                        (1 + self.beta2)
1043                        * precision
1044                        * recall
1045                        / (self.beta2 * precision + recall + 1e-7)
1046                    )
1047                )
1048            f1 = max(scores)
1049        else:
1050            k = self.threshold_values[0]
1051            if isinstance(k, list):
1052                k = ", ".join(map(str, k))
1053            recall = tp[k] / (tp[k] + fn[k] + 1e-7)
1054            precision = tp[k] / (tp[k] + fp[k] + 1e-7)
1055            f1 = (
1056                (1 + self.beta2)
1057                * precision
1058                * recall
1059                / (self.beta2 * precision + recall + 1e-7)
1060            )
1061        return f1
1062
1063
class _SemiSegmentalMetric(_ClassificationMetric):
    """Base class for metrics that are in-between segmental and frame-wise.

    Ground truth intervals are compared with the predicted frames they cover (the results are
    accumulated in `tp` / `fn`) and predicted intervals are compared with the ground truth frames
    they cover (matches are accumulated in `tn`, misses in `fp`). Subclasses turn these counts
    into a metric value.
    """

    def __init__(
        self,
        num_classes: int,
        ignore_index: int = -100,
        ignored_classes: Set = None,
        exclusive: bool = True,
        average: str = "macro",
        tag_average: str = "micro",
        delta: int = 0,
        smooth_interval: int = 0,
        iou_threshold_long: float = 0.5,
        iou_threshold_short: float = 0.5,
        short_length: int = 30,
        long_length: int = 300,
        threshold_values: List = None,
    ) -> None:
        """Initialize the class.

        Parameters
        ----------
        num_classes : int
            the number of classes in the dataset
        ignore_index : int, default -100
            the ground truth label to ignore
        ignored_classes : set, optional
            the class indices to ignore in computation
        exclusive : bool, default True
            `False` for multi-label classification tasks
        average : {"macro", "micro", "none"}
            the method to average the results over classes
        tag_average : {"macro", "micro", "none"}
            the method to average the results over meta tags (if given)
        delta : int, default 0
            the number of frames to add to each interval at both ends before computing the intersection
        smooth_interval : int, default 0
            intervals shorter than this number of frames will be ignored (both in prediction and in ground
            truth); not supported for exclusive classification, where it is reset to 0 with a warning
        iou_threshold_long : float, default 0.5
            the intersection threshold for segments longer than `long_length` frames (between 0 and 1)
        iou_threshold_short : float, default 0.5
            the intersection threshold for segments shorter than `short_length` frames (between 0 and 1)
        short_length : int, default 30
            the threshold number of frames for short intervals that will have an intersection threshold of
            `iou_threshold_short`
        long_length : int, default 300
            the threshold number of frames for long intervals that will have an intersection threshold of
            `iou_threshold_long`
        threshold_values : list, optional
            the decision threshold values (cannot be defined for exclusive classification, and 0.5 by default)

        """
        if smooth_interval > 0 and exclusive:
            warnings.warn(
                "Smoothing is not implemented for datasets with exclusive classification! Setting smooth_interval to 0..."
            )
            # actually apply the fallback that the warning announces
            smooth_interval = 0
        # subclasses wrap a single (possibly unset) threshold value in a list
        if threshold_values == [None]:
            threshold_values = [0.5]
        super().__init__(
            num_classes,
            ignore_index,
            average,
            ignored_classes,
            exclusive,
            tag_average=tag_average,
            threshold_values=threshold_values,
        )
        self.delta = delta
        self.random_sampling = False
        self.smooth_interval = smooth_interval
        self.threshold_function = self._sigmoid_threshold_function(
            low_threshold=iou_threshold_short,
            high_threshold=iou_threshold_long,
            low_length=short_length,
            high_length=long_length,
        )

    def update(
        self,
        predicted: torch.Tensor,
        target: torch.Tensor,
        tags: torch.Tensor,
    ) -> None:
        """Update the internal state (with a batch).

        Parameters
        ----------
        predicted : torch.Tensor
            the main prediction tensor generated by the model
        target : torch.Tensor
            the corresponding main target tensor
        tags : torch.Tensor
            the tensor of meta tags (or `None`, if tags are not given)

        """
        if self.exclusive:
            # turn class scores into class indices (the class dimension is dimension 1)
            predicted = torch.max(predicted, 1)[1]
        # pad both tensors with `delta + 1` ignored (-100) frames along the last dimension;
        # presumably this keeps intervals of neighboring samples apart after flattening
        predicted = torch.cat(
            [
                copy(predicted),
                -100
                * torch.ones((*predicted.shape[:-1], self.delta + 1)).to(
                    predicted.device
                ),
            ],
            dim=-1,
        )
        target = torch.cat(
            [
                copy(target),
                -100
                * torch.ones((*target.shape[:-1], self.delta + 1)).to(target.device),
            ],
            dim=-1,
        )
        if self.exclusive:
            predicted = predicted.flatten()
            target = target.flatten()
        else:
            # both tensors become two-dimensional with the classes along dimension 1
            predicted = predicted.transpose(1, 2).reshape(-1, target.shape[1])
            target = target.transpose(1, 2).reshape(-1, predicted.shape[1])
        if self.tag_average == "micro" or tags is None or tags[0] is None:
            tag_set = {None}
        else:
            tag_set = set(tags)
            self.tags.update(tag_set)
        if self.smooth_interval > 0:
            target = self._smooth(target, self.smooth_interval)
        for thr in self.threshold_values:
            key = thr
            for t in tag_set:
                if t is None:
                    predicted_t = deepcopy(predicted)
                    target_t = target
                else:
                    predicted_t = deepcopy(predicted[tags == t])
                    target_t = target[tags == t]
                if not self.exclusive:
                    if isinstance(thr, list):
                        # per-class thresholds: classes are the columns of `predicted_t`
                        for i, tt in enumerate(thr):
                            predicted_t[:, i] = (predicted_t[:, i] > tt).int()
                        # a list cannot be a dictionary key, so its string form is used
                        key = ", ".join(map(str, thr))
                    else:
                        predicted_t = (predicted_t > thr).int()
                if self.smooth_interval > 0:
                    predicted_t = self._smooth(predicted_t, self.smooth_interval)
                for c in self.classes:
                    # --- ground truth intervals against predicted frames (tp / fn) ---
                    if not self.random_sampling:
                        if self.exclusive:
                            target_intervals = self._get_intervals(target_t == c)
                        else:
                            target_intervals = self._get_intervals(target_t[:, c] == 1)
                        # the lengths are taken before `delta` is added
                        target_lengths = [
                            end - start for start, end in target_intervals
                        ]
                        target_intervals = [
                            [
                                max(start - self.delta, 0),
                                min(end + self.delta, len(target_t)),
                            ]
                            for start, end in target_intervals
                        ]
                    else:
                        if self.exclusive:
                            target_arr = target_t == c
                        else:
                            target_arr = target_t[:, c] == 1
                        # every 20th positive frame becomes the center of an interval
                        target_points = torch.where(target_arr)[0][::20]
                        target_intervals = []
                        for p in target_points:
                            target_intervals.append(
                                [
                                    max(0, p - self.delta),
                                    min(len(target_arr), p + self.delta),
                                ]
                            )
                        target_lengths = [
                            end - start for start, end in target_intervals
                        ]
                    if self.exclusive:
                        predicted_arr = predicted_t == c
                    else:
                        predicted_arr = predicted_t[:, c] == 1
                    for interval, length in zip(target_intervals, target_lengths):
                        intersection = torch.sum(
                            predicted_arr[interval[0] : interval[1]]
                        )
                        iou = intersection / length
                        if iou >= self.threshold_function(length):
                            self.tp[self._key(t, c)][key] += 1
                        else:
                            self.fn[self._key(t, c)][key] += 1

                    # --- predicted intervals against ground truth frames (tn / fp) ---
                    if not self.random_sampling:
                        if self.exclusive:
                            predicted_intervals = self._get_intervals(predicted_t == c)
                        else:
                            predicted_intervals = self._get_intervals(
                                predicted_t[:, c] == 1
                            )
                        predicted_intervals_delta = [
                            [
                                max(start - self.delta, 0),
                                min(end + self.delta, len(target_t)),
                            ]
                            for start, end in predicted_intervals
                        ]
                    else:
                        if self.exclusive:
                            predicted_arr = predicted_t == c
                        else:
                            predicted_arr = predicted_t[:, c] == 1
                        predicted_points = torch.where(predicted_arr)[0][::20]
                        predicted_intervals = []
                        for p in predicted_points:
                            predicted_intervals.append(
                                [
                                    max(0, p - self.delta),
                                    min(len(predicted_arr), p + self.delta),
                                ]
                            )
                        predicted_intervals_delta = predicted_intervals
                    if self.exclusive:
                        # an integer tensor is required here: a boolean tensor would store
                        # the -100 marker as `True` and the ignored frames would be lost
                        target_arr = (target_t == c).int()
                        target_arr[target_t == -100] = -100
                    else:
                        target_arr = target_t[:, c]
                    for interval, interval_delta in zip(
                        predicted_intervals, predicted_intervals_delta
                    ):
                        # skip predicted intervals that mostly fall on ignored frames
                        if torch.sum(
                            target_arr[interval_delta[0] : interval_delta[1]] != -100
                        ) < 0.3 * (interval[1] - interval[0]):
                            continue
                        length = torch.sum(
                            target_arr[interval[0] : interval[1]] != -100
                        )
                        if length == 0:
                            # no labeled frames inside the interval itself: the ratio below
                            # would be a division by zero, so the interval cannot be judged
                            continue
                        intersection = torch.sum(
                            target_arr[interval_delta[0] : interval_delta[1]] == 1
                        )
                        iou = intersection / length
                        if iou >= self.threshold_function(length):
                            # matched predicted intervals are accumulated in `tn`
                            self.tn[self._key(t, c)][key] += 1
                        else:
                            self.fp[self._key(t, c)][key] += 1
1314
1315
class SemiSegmentalRecall(_SemiSegmentalMetric):
    """Semi-segmental recall (not advised for training).

    A metric in-between segmental and frame-wise recall.

    The value is computed in the following steps:
    1) too-short intervals are smoothed over, both in ground truth and in prediction (see
        `smooth_interval`),
    2) every ground truth interval is extended by `delta` frames at both ends and the predicted
        positive frames inside the extended interval are counted (intersection),
    3) a threshold `t` is computed for every interval from its length `x` before `delta` was added;
        it is a sigmoid-shaped function of `x` configured by `iou_threshold_short`,
        `iou_threshold_long`, `short_length` and `long_length` (see `_sigmoid_threshold_function`
        for the exact formula),
    4) an interval with an intersection of at least `t * x` is counted as a true positive (`TP`),
        any other interval as a false negative (`FN`),
    5) the final metric value is `TP / (TP + FN)`.
    """

    def __init__(
        self,
        num_classes: int,
        ignore_index: int = -100,
        ignored_classes: Set = None,
        exclusive: bool = True,
        average: str = "macro",
        tag_average: str = "micro",
        delta: int = 0,
        smooth_interval: int = 0,
        iou_threshold_long: float = 0.5,
        iou_threshold_short: float = 0.5,
        short_length: int = 30,
        long_length: int = 300,
        threshold_value: Union[float, List] = None,
    ) -> None:
        """Initialize the class.

        Parameters
        ----------
        num_classes : int
            the number of classes in the dataset
        ignore_index : int, default -100
            the ground truth label to ignore
        ignored_classes : set, optional
            the class indices to ignore in computation
        exclusive : bool, default True
            `False` for multi-label classification tasks
        average : {"macro", "micro", "none"}
            the method to average the results over classes
        tag_average : {"macro", "micro", "none"}
            the method to average the results over meta tags (if given)
        delta : int, default 0
            the number of frames to add to each ground truth interval before computing the intersection,
            see description of the class for details
        smooth_interval : int, default 0
            intervals shorter than this number of frames will be ignored (both in prediction and in ground
            truth), see description of the class for details
        iou_threshold_long : float, default 0.5
            the intersection threshold for segments longer than `long_length` frames (between 0 and 1),
            see description of the class for details
        iou_threshold_short : float, default 0.5
            the intersection threshold for segments shorter than `short_length` frames (between 0 and 1),
            see description of the class for details
        short_length : int, default 30
            the threshold number of frames for short intervals that will have an intersection threshold of
            `iou_threshold_short`, see description of the class for details
        long_length : int, default 300
            the threshold number of frames for long intervals that will have an intersection threshold of
            `iou_threshold_long`, see description of the class for details
        threshold_value : float | list, optional
            the decision threshold value (cannot be defined for exclusive classification, and 0.5 by default)

        """
        # the parent class works with a list of thresholds, so the single value is wrapped
        super().__init__(
            num_classes=num_classes,
            ignore_index=ignore_index,
            ignored_classes=ignored_classes,
            exclusive=exclusive,
            average=average,
            tag_average=tag_average,
            delta=delta,
            smooth_interval=smooth_interval,
            iou_threshold_long=iou_threshold_long,
            iou_threshold_short=iou_threshold_short,
            short_length=short_length,
            long_length=long_length,
            threshold_values=[threshold_value],
        )

    def _calculate_metric(self, tp: Dict, fp: Dict, fn: Dict, tn: Dict) -> float:
        """Compute recall from the accumulated true positive and false negative counts."""
        threshold = self.threshold_values[0]
        if isinstance(threshold, list):
            # counts for per-class thresholds are stored under the joined string form
            key = ", ".join(str(value) for value in threshold)
        else:
            key = threshold
        true_positives = tp[key]
        return true_positives / (true_positives + fn[key] + 1e-7)
1411
1412
1413class SemiSegmentalPrecision(_SemiSegmentalMetric):
1414    """Semi-segmental precision (not advised for training).
1415
1416    A metric in-between segmental and frame-wise precision.
1417
1418    This metric follows the following algorithm:
1419    1) smooth over too-short intervals, both in ground truth and in prediction (first remove
1420        groups of zeros shorter than `smooth_interval` and then do the same with groups of ones),
1421    2) add `delta` frames to each predicted interval at both ends and count the number of ground truth
1422        positive frames at the resulting intervals (intersection),
1423    3) calculate the threshold for each interval as
1424        `t = sigmoid(4 * (a * x + b)) * (iou_threshold_long - iou_threshold_short))`, where
1425        `a = 2 / (long_length - short_length)`, `b = 1 - a * long_length`, `x` is the length of the interval
1426        before `delta` was added,
1427    4) for each interval, if intersection is higher than `t * x`, the interval is labeled as true positive (`TP`),
1428        and otherwise as false positive (`FP`),
1429    5) the final metric value is computed as `TP / (TP + FP)`.
1430
1431    """
1432    def __init__(
1433        self,
1434        num_classes: int,
1435        ignore_index: int = -100,
1436        ignored_classes: Set = None,
1437        exclusive: bool = True,
1438        average: str = "macro",
1439        tag_average: str = "micro",
1440        delta: int = 0,
1441        smooth_interval: int = 0,
1442        iou_threshold_long: float = 0.5,
1443        iou_threshold_short: float = 0.5,
1444        short_length: int = 30,
1445        long_length: int = 300,
1446        threshold_value: Union[float, List] = None,
1447    ) -> None:
1448        """Initialize the class.
1449
1450        Parameters
1451        ----------
1452        num_classes : int
1453            the number of classes in the dataset
1454        ignore_index : int, default -100
1455            the ground truth label to ignore
1456        ignored_classes : set, optional
1457            the class indices to ignore in computation
1458        exclusive : bool, default True
1459            `False` for multi-label classification tasks
1460        average : {"macro", "micro", "none"}
1461            the method to average the results over classes
1462        tag_average : {"macro", "micro", "none"}
1463            the method to average the results over meta tags (if given)
1464        delta : int, default 0
1465            the number of frames to add to each ground truth interval before computing the intersection,
1466            see description of the class for details
1467        smooth_interval : int, default 0
1468            intervals shorter than this number of frames will be ignored (both in prediction and in ground truth,
1469            see description of the class for details
1470        iou_threshold_long : float, default 0.5
1471            the intersection threshold for segments longer than `long_length` frames (between 0 and 1),
1472            see description of the class for details
1473        iou_threshold_short : float, default 0.5
1474            the intersection threshold for segments shorter than `short_length` frames (between 0 and 1),
1475            see description of the class for details
1476        short_length : int, default 30
1477            the threshold number of frames for short intervals that will have an intersection threshold of
1478            `iou_threshold_short`, see description of the class for details
1479        long_length : int, default 300
1480            the threshold number of frames for long intervals that will have an intersection threshold of
1481            `iou_threshold_long`, see description of the class for details
1482        threshold_value : float | list, optional
1483            the decision threshold value (cannot be defined for exclusive classification, and 0.5 be default)
1484
1485        """
1486        super().__init__(
1487            num_classes,
1488            ignore_index,
1489            ignored_classes,
1490            exclusive,
1491            average,
1492            tag_average,
1493            delta,
1494            smooth_interval,
1495            iou_threshold_long,
1496            iou_threshold_short,
1497            short_length,
1498            long_length,
1499            [threshold_value],
1500        )
1501
1502    def _calculate_metric(self, tp: Dict, fp: Dict, fn: Dict, tn: Dict) -> float:
1503        """Calculate the metric value from true and false positive and negative rates."""
1504        k = self.threshold_values[0]
1505        if isinstance(k, list):
1506            k = ", ".join(map(str, k))
1507        return tn[k] / (tn[k] + fp[k] + 1e-7)
1508
1509
class SemiSegmentalF1(_SemiSegmentalMetric):
    """The F1 score for semi-segmental recall and precision (not advised for training)."""

    def __init__(
        self,
        num_classes: int,
        ignore_index: int = -100,
        ignored_classes: Set = None,
        exclusive: bool = True,
        average: str = "macro",
        tag_average: str = "micro",
        delta: int = 0,
        smooth_interval: int = 0,
        iou_threshold_long: float = 0.5,
        iou_threshold_short: float = 0.5,
        short_length: int = 30,
        long_length: int = 300,
        threshold_value: Union[float, List] = None,
    ) -> None:
        """Set up the metric.

        All arguments are forwarded unchanged (and in the same order) to the parent initializer;
        the only transformation is that `threshold_value` is wrapped in a one-element list.

        Parameters
        ----------
        num_classes : int
            the number of classes in the dataset
        ignore_index : int, default -100
            the ground truth label to ignore
        ignored_classes : set, optional
            the class indices to ignore in computation
        exclusive : bool, default True
            `False` for multi-label classification tasks
        average : {"macro", "micro", "none"}
            the method to average the results over classes
        tag_average : {"macro", "micro", "none"}
            the method to average the results over meta tags (if given)
        delta : int, default 0
            the number of frames to add to each interval before computing the intersection
        smooth_interval : int, default 0
            intervals shorter than this number of frames will be ignored (both in prediction and in
            ground truth)
        iou_threshold_long : float, default 0.5
            the intersection threshold for segments longer than `long_length` frames (between 0 and 1)
        iou_threshold_short : float, default 0.5
            the intersection threshold for segments shorter than `short_length` frames (between 0 and 1)
        short_length : int, default 30
            the threshold number of frames for short intervals that will have an intersection threshold of
            `iou_threshold_short`
        long_length : int, default 300
            the threshold number of frames for long intervals that will have an intersection threshold of
            `iou_threshold_long`
        threshold_value : float | list, optional
            the decision threshold value (cannot be defined for exclusive classification, and 0.5 by default)

        """
        # the single decision threshold is handed over as a one-element list
        threshold_values = [threshold_value]
        super().__init__(
            num_classes,
            ignore_index,
            ignored_classes,
            exclusive,
            average,
            tag_average,
            delta,
            smooth_interval,
            iou_threshold_long,
            iou_threshold_short,
            short_length,
            long_length,
            threshold_values,
        )

    def _calculate_metric(self, tp: Dict, fp: Dict, fn: Dict, tn: Dict) -> float:
        """Compute F1 from the accumulated counts.

        When `self.optimize` is set, the best F1 over all decision thresholds is returned;
        otherwise F1 at the first threshold is returned.
        """

        def f1_at(key):
            # NOTE(review): precision reads its numerator from `tn`; presumably the base class
            # stores the correctly predicted intervals there - confirm against it.
            recall = tp[key] / (tp[key] + fn[key] + 1e-7)
            precision = tn[key] / (tn[key] + fp[key] + 1e-7)
            return 2 * recall * precision / (recall + precision + 1e-7)

        if self.optimize:
            return max([f1_at(threshold) for threshold in self.threshold_values])

        key = self.threshold_values[0]
        if isinstance(key, list):
            # a list of per-class thresholds is looked up under its comma-separated string form
            key = ", ".join(str(value) for value in key)
        return f1_at(key)
1600
1601
class SemiSegmentalPR_AUC(_SemiSegmentalMetric):
    """The area under the precision-recall curve for semi-segmental metrics (not advised for training)."""

    def __init__(
        self,
        num_classes: int,
        ignore_index: int = -100,
        ignored_classes: Set = None,
        exclusive: bool = True,
        average: str = "macro",
        tag_average: str = "micro",
        delta: int = 0,
        smooth_interval: int = 0,
        iou_threshold_long: float = 0.5,
        iou_threshold_short: float = 0.5,
        short_length: int = 30,
        long_length: int = 300,
        threshold_step: float = 0.1,
    ) -> None:
        """Set up the metric.

        All arguments except `threshold_step` are forwarded unchanged (and in the same order) to
        the parent initializer.

        Parameters
        ----------
        num_classes : int
            the number of classes in the dataset
        ignore_index : int, default -100
            the ground truth label to ignore
        ignored_classes : set, optional
            the class indices to ignore in computation
        exclusive : bool, default True
            `False` for multi-label classification tasks
        average : {"macro", "micro", "none"}
            the method to average the results over classes
        tag_average : {"macro", "micro", "none"}
            the method to average the results over meta tags (if given)
        delta : int, default 0
            the number of frames to add to each interval before computing the intersection
        smooth_interval : int, default 0
            intervals shorter than this number of frames will be ignored
        iou_threshold_long : float, default 0.5
            the intersection threshold for segments longer than `long_length` frames
        iou_threshold_short : float, default 0.5
            the intersection threshold for segments shorter than `short_length` frames
        short_length : int, default 30
            the threshold number of frames for short intervals
        long_length : int, default 300
            the threshold number of frames for long intervals
        threshold_step : float, default 0.1
            the step of the decision threshold grid, which covers `[0, 1)`

        """
        # decision thresholds at which the precision-recall curve is sampled
        threshold_values = list(np.arange(0, 1, threshold_step))
        super().__init__(
            num_classes,
            ignore_index,
            ignored_classes,
            exclusive,
            average,
            tag_average,
            delta,
            smooth_interval,
            iou_threshold_long,
            iou_threshold_short,
            short_length,
            long_length,
            threshold_values,
        )

    def _calculate_metric(self, tp: Dict, fp: Dict, fn: Dict, tn: Dict) -> float:
        """Integrate precision over recall across the decision thresholds (in ascending order)."""
        thresholds = sorted(self.threshold_values)
        # NOTE(review): precision reads its numerator from `tn`; presumably the base class
        # stores the correctly predicted intervals there - confirm against it.
        precisions = [tn[k] / (tn[k] + fp[k] + 1e-7) for k in thresholds]
        recalls = [tp[k] / (tp[k] + fn[k] + 1e-7) for k in thresholds]
        return metrics.auc(x=recalls, y=precisions)
1645
1646
class Accuracy(Metric):
    """Accuracy: the fraction of non-ignored samples where prediction and target agree."""

    def __init__(self, ignore_index=-100):
        """Set up the metric.

        Parameters
        ----------
        ignore_index: int
            the class index that indicates ignored sample

        """
        super().__init__()
        self.ignore_index = ignore_index

    def reset(self) -> None:
        """Clear the accumulated counters (at the beginning of an epoch)."""
        self.correct = 0
        self.total = 0

    def calculate(self) -> float:
        """Compute the metric value from the accumulated counters.

        Returns
        -------
        metric : float
            the number of correct samples divided by the number of counted samples (a small
            constant in the denominator avoids division by zero); since `update` accumulates
            with `torch.sum`, the value is a tensor once `update` has been called

        """
        denominator = self.total + 1e-7
        return self.correct / denominator

    def update(
        self,
        predicted: torch.Tensor,
        target: torch.Tensor,
        tags: torch.Tensor = None,
    ) -> None:
        """Accumulate the counters with a batch.

        Parameters
        ----------
        predicted : torch.Tensor
            the main prediction tensor generated by the model (compared element-wise with `target`)
        target : torch.Tensor
            the corresponding main target tensor
        tags : torch.Tensor
            the tensor of meta tags (or `None`, if tags are not given); not used here

        """
        # samples labeled with ignore_index are excluded from both counters
        valid = target != self.ignore_index
        matches = (target == predicted)[valid]
        self.total += valid.sum()
        self.correct += matches.sum()
1703
1704
class Count(Metric):
    """Fraction of samples labeled by the model as a class."""

    def __init__(self, classes: Set, exclusive: bool = True):
        """Initialize the class.

        Parameters
        ----------
        classes : set
            the set of classes to count
        exclusive: bool, default True
            set to False for multi-label classification tasks

        """
        super().__init__()
        self.classes = classes
        self.exclusive = exclusive

    def reset(self) -> None:
        """Reset the internal state (at the beginning of an epoch)."""
        self.count = defaultdict(int)
        self.total = 0

    def update(
        self,
        predicted: torch.Tensor,
        target: torch.Tensor,
        tags: torch.Tensor,
    ) -> None:
        """Update the internal state (with a batch).

        Parameters
        ----------
        predicted : torch.Tensor
            the main prediction tensor generated by the model; class indices when `exclusive`
            is `True`, otherwise indexed as `predicted[:, c, :]` with 1 marking a positive
        target : torch.Tensor
            the corresponding main target tensor (not used here)
        tags : torch.Tensor
            the tensor of meta tags (or `None`, if tags are not given); not used here

        """
        if self.exclusive:
            for c in self.classes:
                self.count[c] += torch.sum(predicted == c)
            self.total += torch.numel(predicted)
        else:
            for c in self.classes:
                self.count[c] += torch.sum(predicted[:, c, :] == 1)
            # every class row has the same number of samples, so any one of them gives the total
            self.total += torch.numel(predicted[:, 0, :])

    def calculate(self) -> Dict:
        """Calculate the metric (at the end of an epoch).

        The accumulated counters are left untouched, so the method can be called repeatedly
        between two `reset` calls and returns the same values each time.

        Returns
        -------
        result : dict
            a dictionary where the keys are class indices and the values are class metric values

        """
        # small constant avoids division by zero when nothing has been counted yet
        denominator = self.total + 1e-7
        return {c: self.count[c] / denominator for c in self.classes}
1771
1772
class EditDistance(Metric):
    """Edit distance (not advised for training).

    Normalized by the length of the sequences.
    """

    def __init__(self, ignore_index: int = -100, exclusive: bool = True) -> None:
        """Set up the metric.

        Parameters
        ----------
        ignore_index : int, default -100
            the class index that indicates samples that should be ignored
        exclusive : bool, default True
            set to False for multi-label classification tasks

        """
        super().__init__()
        self.exclusive = exclusive
        self.ignore_index = ignore_index

    def reset(self) -> None:
        """Clear the accumulated distance and sample count (at the beginning of an epoch)."""
        self.total = 0
        self.edit_distance = 0

    def update(
        self,
        predicted: torch.Tensor,
        target: torch.Tensor,
        tags: torch.Tensor,
    ) -> None:
        """Accumulate the edit distance with a batch.

        Parameters
        ----------
        predicted : torch.Tensor
            the main prediction tensor generated by the model
        target : torch.Tensor
            the corresponding main target tensor
        tags : torch.Tensor
            the tensor of meta tags (or `None`, if tags are not given); not used here

        """
        # samples labeled with ignore_index are dropped from both sequences
        valid = target != self.ignore_index
        self.total += torch.sum(valid)
        if self.exclusive:
            predicted_sequence = predicted[valid].flatten().detach().cpu().numpy()
            target_sequence = target[valid].flatten().detach().cpu().numpy()
            self.edit_distance += editdistance.eval(predicted_sequence, target_sequence)
            return
        # multi-label case: each class row is compared as a separate sequence
        for c in range(target.shape[1]):
            class_valid = valid[:, c, :]
            predicted_sequence = predicted[:, c, :][class_valid].flatten()
            target_sequence = target[:, c, :][class_valid].flatten()
            self.edit_distance += editdistance.eval(
                predicted_sequence.detach().cpu().tolist(),
                target_sequence.detach().cpu().tolist(),
            )

    def _is_equal(self, a, b):
        """Tell whether two labels match, treating `ignore_index` as matching anything."""
        return bool(self.ignore_index in [a, b] or a == b)

    def calculate(self) -> float:
        """Compute the metric (at the end of an epoch).

        Returns
        -------
        result : float
            the accumulated edit distance divided by the number of counted samples

        """
        denominator = self.total + 1e-7
        return self.edit_distance / denominator
1855
1856
class mAP(Metric):
    """Mean average precision (segmental) (not advised for training).

    For every class, predicted and ground truth frames are grouped into intervals, predicted
    intervals are matched to ground truth intervals by overlap and the average precision is
    computed with the mean predicted probability of each interval as its confidence.
    """

    needs_raw_data = True

    def __init__(
        self,
        exclusive,
        num_classes,
        average="macro",
        iou_threshold=0.5,
        threshold_value=0.5,
        ignored_classes=None,
    ):
        """Initialize the class.

        Parameters
        ----------
        exclusive : bool
            set to False for multi-label classification tasks
        num_classes : int
            the number of classes
        average : {"macro", "micro", "none"}
            the type of averaging to perform
        iou_threshold : float, default 0.5
            the IoU threshold for matching
        threshold_value : float, default 0.5
            the threshold value for binarization
        ignored_classes : list, default None
            the list of classes to ignore

        """
        if ignored_classes is None:
            ignored_classes = []
        self.average = average
        self.iou_threshold = iou_threshold
        self.threshold = threshold_value
        self.exclusive = exclusive
        self.classes = [x for x in range(num_classes) if x not in ignored_classes]
        # the parent initializer is deliberately called after the attributes are set
        super().__init__()

    def match(self, lst, ratio, ground):
        """Match proposals to ground truth intervals.

        Every proposal is assigned to the ground truth interval it overlaps with the most,
        provided that overlap is at least `ratio` (ties go to the later interval).

        Parameters
        ----------
        lst : iterable
            the proposals as `(start, end, confidence)` triplets
        ratio : float
            the minimum overlap for a match
        ground : sequence
            the ground truth intervals as `(start, end, _)` triplets

        Returns
        -------
        cos_map : list
            for every proposal (in ascending order of confidence), the index of the matched
            ground truth interval or -1 if there is none
        count_map : list
            for every ground truth interval, the number of proposals matched to it
        positive : int
            the number of ground truth intervals matched by at least one proposal
        confidence : list
            the proposal confidences in ascending order

        """
        proposals = sorted(lst, key=lambda x: x[2])

        def overlap(prop, truth):
            # intersection length over the length of the smallest interval covering both
            s_p, e_p, _ = prop
            s_g, e_g, _ = truth
            return (min(e_p, e_g) - max(s_p, s_g)) / (max(e_p, e_g) - min(s_p, s_g))

        cos_map = []
        count_map = [0 for _ in range(len(ground))]

        for proposal in proposals:
            best_index, best_overlap = -1, None
            for index, truth in enumerate(ground):
                value = overlap(proposal, truth)
                if value < ratio:
                    continue
                if best_index != -1 and value < best_overlap:
                    continue
                best_index, best_overlap = index, value
            cos_map.append(best_index)
            if best_index != -1:
                count_map[best_index] += 1
        positive = sum([(x > 0) for x in count_map])
        return cos_map, count_map, positive, [x[2] for x in proposals]

    def reset(self) -> None:
        """Reset the internal state (at the beginning of an epoch)."""
        # all state is stored per class index
        self.count_map = defaultdict(list)
        self.positive = defaultdict(int)
        self.cos_map = defaultdict(list)
        self.confidence = defaultdict(list)

    def calc_pr(self, positive, proposal, ground):
        """Get precision and recall.

        Returns `(0, 0)` when there are no proposals or no ground truth intervals.
        """
        if proposal == 0:
            return 0, 0
        if ground == 0:
            return 0, 0
        return (1.0 * positive) / proposal, (1.0 * positive) / ground

    def calculate(self) -> Union[float, Dict]:
        """Calculate the metric (at the end of an epoch).

        Returns
        -------
        result : float | dict
            a single value for "micro" (all classes pooled) and "macro" (mean over classes)
            averaging, a dictionary of per-class values for "none"

        """
        if self.average == "micro":
            confidence = []
            count_map = []
            cos_map = []
            positive = sum(self.positive.values())
            for key in self.count_map.keys():
                confidence += self.confidence[key]
                # ground truth indices are shifted into the pooled list; -1 marks an
                # unmatched proposal and has to stay -1
                shifted = np.array(self.cos_map[key], dtype=int)
                shifted[shifted != -1] += len(count_map)
                cos_map += list(shifted)
                count_map += self.count_map[key]
            return self.ap(cos_map, count_map, positive, confidence)
        results = {
            key: self.ap(
                self.cos_map[key],
                self.count_map[key],
                self.positive[key],
                self.confidence[key],
            )
            for key in self.count_map.keys()
        }
        if self.average == "none":
            return results
        else:
            return float(np.mean(list(results.values())))

    def ap(self, cos_map, count_map, positive, confidence):
        """Compute average precision.

        Proposals are removed one by one in ascending order of confidence; after every removal
        of a matched proposal precision and recall are recomputed and the (monotonically
        non-decreasing) precision is integrated over the drop in recall.

        Parameters
        ----------
        cos_map : list
            for every proposal, the index of the matched ground truth interval or -1
        count_map : list
            for every ground truth interval, the number of proposals matched to it
        positive : int
            the number of ground truth intervals with at least one match
        confidence : list
            the confidence of every proposal

        Returns
        -------
        score : float
            the average precision

        """
        indices = np.argsort(confidence)
        cos_map = list(np.array(cos_map)[indices])
        # work on a copy so that the accumulated state is not modified by the computation
        count_map = list(count_map)
        score = 0
        number_proposal = len(cos_map)
        number_ground = len(count_map)
        old_precision, old_recall = self.calc_pr(
            positive, number_proposal, number_ground
        )

        for matched in cos_map:
            number_proposal -= 1
            if matched == -1:
                continue
            count_map[matched] -= 1
            if count_map[matched] == 0:
                # the last proposal matched to this ground truth interval was removed
                positive -= 1

            precision, recall = self.calc_pr(positive, number_proposal, number_ground)
            if precision > old_precision:
                old_precision = precision
            score += old_precision * (old_recall - recall)
            old_recall = recall
        return score

    def _get_intervals(
        self, tensor: torch.Tensor, probability: torch.Tensor = None
    ) -> Union[Tuple, torch.Tensor]:
        """Get `True` group beginning and end indices from a boolean tensor and average probability over these intervals.

        Parameters
        ----------
        tensor : torch.Tensor
            a boolean tensor (one-dimensional when called from `update`)
        probability : torch.Tensor
            the values to average over each interval; despite the `None` default it has to be
            given whenever `tensor` contains a `True` group

        Returns
        -------
        intervals : torch.Tensor
            a tensor of shape `(#intervals, 3)` with the start index, the exclusive end index and
            the mean probability of each interval

        """
        output, indices = torch.unique_consecutive(tensor, return_inverse=True)
        true_indices = torch.where(output)[0]
        starts = torch.tensor(
            [(indices == i).nonzero(as_tuple=True)[0][0] for i in true_indices]
        )
        ends = torch.tensor(
            [(indices == i).nonzero(as_tuple=True)[0][-1] + 1 for i in true_indices]
        )
        confidence = torch.tensor(
            [probability[indices == i].mean() for i in true_indices]
        )
        return torch.stack([starts, ends, confidence]).T

    def update(
        self,
        predicted: torch.Tensor,
        target: torch.Tensor,
        tags: torch.Tensor,
    ) -> None:
        """Update the state (at the end of each batch).

        Parameters
        ----------
        predicted : torch.Tensor
            the raw prediction tensor; the class dimension is read from axis 1 and frames from
            the last axis
        target : torch.Tensor
            the corresponding target tensor (class indices if `exclusive`, otherwise per-class
            binary labels with the class dimension at axis 1)
        tags : torch.Tensor
            the tensor of meta tags (or `None`, if tags are not given); not used here

        """
        # a frame of -100 is appended to every sequence so that intervals cannot run over from
        # one sequence into the next after the batch is flattened
        padded_length = predicted.shape[-1] + 1
        predicted = torch.cat(
            [
                copy(predicted),
                -100 * torch.ones((*predicted.shape[:-1], 1)).to(predicted.device),
            ],
            dim=-1,
        )
        target = torch.cat(
            [
                copy(target),
                -100 * torch.ones((*target.shape[:-1], 1)).to(target.device),
            ],
            dim=-1,
        )
        num_classes = predicted.shape[1]
        predicted = predicted.transpose(1, 2).reshape(-1, num_classes)
        if self.exclusive:
            target = target.flatten()
        else:
            target = target.transpose(1, 2).reshape(-1, num_classes)
        probability = copy(predicted)
        if not self.exclusive:
            predicted = (predicted > self.threshold).int()
        else:
            predicted = torch.max(predicted, 1)[1]
            # all classes are tied at the separator frames, so the arg-max would label them as
            # class 0; they are marked as -100 instead so that they belong to no class
            frame_in_sequence = (
                torch.arange(predicted.shape[0], device=predicted.device)
                % padded_length
            )
            predicted[frame_in_sequence == padded_length - 1] = -100
        for c in self.classes:
            if self.exclusive:
                predicted_intervals = self._get_intervals(
                    predicted == c, probability=probability[:, c]
                )
                target_intervals = self._get_intervals(
                    target == c, probability=probability[:, c]
                )
            else:
                predicted_intervals = self._get_intervals(
                    predicted[:, c] == 1, probability=probability[:, c]
                )
                target_intervals = self._get_intervals(
                    target[:, c] == 1, probability=probability[:, c]
                )
            cos_map, count_map, positive, confidence = self.match(
                predicted_intervals, self.iou_threshold, target_intervals
            )
            cos_map = np.array(cos_map)
            # shift the matched indices past the ground truth intervals of previous batches
            cos_map[cos_map != -1] += len(self.count_map[c])
            self.cos_map[c] += list(cos_map)
            self.count_map[c] += count_map
            self.confidence[c] += confidence
            self.positive[c] += positive
class PR_AUC(_ClassificationMetric):
class PR_AUC(_ClassificationMetric):
    """Area under precision-recall curve (not advised for training)."""

    def __init__(
        self,
        num_classes: int,
        ignore_index: int = -100,
        average: str = "macro",
        ignored_classes: Set = None,
        exclusive: bool = False,
        tag_average: str = "micro",
        threshold_step: float = 0.1,
    ):
        """Set up the metric.

        Parameters
        ----------
        num_classes : int
            number of classes
        ignore_index : int, default -100
            the class index that indicates ignored samples
        average: {'macro', 'micro', 'none'}
            method for averaging across classes
        ignored_classes : set, optional
            a set of class ids to ignore in calculation
        exclusive: bool, default False
            has to be `False` (multi-label classification); `True` raises a `ValueError`
        tag_average: {'micro', 'macro', 'none'}
            method for averaging across meta tags (if given)
        threshold_step : float, default 0.1
            the decision threshold step; the thresholds cover `[0, 1)`

        """
        if exclusive:
            raise ValueError(
                "The PR-AUC metric is not implemented for exclusive classification!"
            )
        # decision thresholds at which the precision-recall curve is sampled
        thresholds = list(np.arange(0, 1, threshold_step))
        super().__init__(
            num_classes=num_classes,
            ignore_index=ignore_index,
            average=average,
            ignored_classes=ignored_classes,
            exclusive=exclusive,
            tag_average=tag_average,
            threshold_values=thresholds,
        )

    def _calculate_metric(self, tp: Dict, fp: Dict, fn: Dict, tn: Dict) -> float:
        """Integrate precision over recall across the decision thresholds (in ascending order)."""
        thresholds = sorted(self.threshold_values)
        precisions = [tp[k] / (tp[k] + fp[k] + 1e-7) for k in thresholds]
        recalls = [tp[k] / (tp[k] + fn[k] + 1e-7) for k in thresholds]
        return metrics.auc(x=recalls, y=precisions)

Area under precision-recall curve (not advised for training).

PR_AUC( num_classes: int, ignore_index: int = -100, average: str = 'macro', ignored_classes: Set = None, exclusive: bool = False, tag_average: str = 'micro', threshold_step: float = 0.1)
467    def __init__(
468        self,
469        num_classes: int,
470        ignore_index: int = -100,
471        average: str = "macro",
472        ignored_classes: Set = None,
473        exclusive: bool = False,
474        tag_average: str = "micro",
475        threshold_step: float = 0.1,
476    ):
477        """Initialize the metric.
478
479        Parameters
480        ----------
481        ignore_index : int, default -100
482            the class index that indicates ignored samples
483        average: {'macro', 'micro', 'none'}
484            method for averaging across classes
485        num_classes : int, optional
486            number of classes (not necessary if main_class is not None)
487        ignored_classes : set, optional
488            a set of class ids to ignore in calculation
489        exclusive: bool, default True
490            set to False for multi-label classification tasks
491        tag_average: {'micro', 'macro', 'none'}
492            method for averaging across meta tags (if given)
493        threshold_step : float, default 0.1
494            the decision threshold step
495
496        """
497        if exclusive:
498            raise ValueError(
499                "The PR-AUC metric is not implemented for exclusive classification!"
500            )
501        super().__init__(
502            num_classes,
503            ignore_index,
504            average,
505            ignored_classes,
506            exclusive,
507            tag_average=tag_average,
508            threshold_values=list(np.arange(0, 1, threshold_step)),
509        )

Initialize the metric.

Parameters

ignore_index : int, default -100 the class index that indicates ignored samples average: {'macro', 'micro', 'none'} method for averaging across classes num_classes : int, optional number of classes (not necessary if main_class is not None) ignored_classes : set, optional a set of class ids to ignore in calculation exclusive: bool, default False must stay False (the metric is only implemented for multi-label classification tasks; True raises a ValueError) tag_average: {'micro', 'macro', 'none'} method for averaging across meta tags (if given) threshold_step : float, default 0.1 the decision threshold step

class Precision(_ClassificationMetric):
class Precision(_ClassificationMetric):
    """Precision."""

    def __init__(
        self,
        num_classes: int,
        ignore_index: int = -100,
        average: str = "macro",
        ignored_classes: Set = None,
        exclusive: bool = True,
        tag_average: str = "micro",
        threshold_value: Union[float, List] = None,
    ):
        """Set up the metric.

        Parameters
        ----------
        num_classes : int
            number of classes
        ignore_index : int, default -100
            the class index that indicates ignored samples
        average: {'macro', 'micro', 'none'}
            method for averaging across classes
        ignored_classes : set, optional
            a set of class ids to ignore in calculation
        exclusive: bool, default True
            set to False for multi-label classification tasks
        tag_average: {'micro', 'macro', 'none'}
            method for averaging across meta tags (if given)
        threshold_value : float | list, optional
            the decision threshold value (cannot be defined for exclusive classification, and 0.5 by
            default for non-exclusive); if `threshold_value` is a list, every value should correspond
            to the class under the same index

        """
        # `None` falls back to a decision threshold of 0.5
        threshold = 0.5 if threshold_value is None else threshold_value
        super().__init__(
            num_classes=num_classes,
            ignore_index=ignore_index,
            average=average,
            ignored_classes=ignored_classes,
            exclusive=exclusive,
            tag_average=tag_average,
            threshold_values=[threshold],
        )

    def _calculate_metric(self, tp: Dict, fp: Dict, fn: Dict, tn: Dict) -> float:
        """Compute precision from the accumulated counts at the first decision threshold."""
        key = self.threshold_values[0]
        if isinstance(key, list):
            # a list of per-class thresholds is looked up under its comma-separated string form
            key = ", ".join(str(value) for value in key)
        return tp[key] / (tp[key] + fp[key] + 1e-7)

Precision.

Precision( num_classes: int, ignore_index: int = -100, average: str = 'macro', ignored_classes: Set = None, exclusive: bool = True, tag_average: str = 'micro', threshold_value: Union[float, List] = None)
524    def __init__(
525        self,
526        num_classes: int,
527        ignore_index: int = -100,
528        average: str = "macro",
529        ignored_classes: Set = None,
530        exclusive: bool = True,
531        tag_average: str = "micro",
532        threshold_value: Union[float, List] = None,
533    ):
534        """Initialize the class.
535
536        Parameters
537        ----------
538        ignore_index : int, default -100
539            the class index that indicates ignored samples
540        average: {'macro', 'micro', 'none'}
541            method for averaging across classes
542        num_classes : int, optional
543            number of classes (not necessary if main_class is not None)
544        ignored_classes : set, optional
545            a set of class ids to ignore in calculation
546        exclusive: bool, default True
547            set to False for multi-label classification tasks
548        tag_average: {'micro', 'macro', 'none'}
549            method for averaging across meta tags (if given)
550        threshold_value : float | list, optional
551            the decision threshold value (cannot be defined for exclusive classification, and 0.5 be default
552            for non-exclusive); if `threshold_value` is a list, every value should correspond to the class
553            under the same index
554
555        """
556        if threshold_value is None:
557            threshold_value = 0.5
558        super().__init__(
559            num_classes,
560            ignore_index,
561            average,
562            ignored_classes,
563            exclusive,
564            tag_average=tag_average,
565            threshold_values=[threshold_value],
566        )

Initialize the class.

Parameters

ignore_index : int, default -100 the class index that indicates ignored samples average: {'macro', 'micro', 'none'} method for averaging across classes num_classes : int, optional number of classes (not necessary if main_class is not None) ignored_classes : set, optional a set of class ids to ignore in calculation exclusive: bool, default True set to False for multi-label classification tasks tag_average: {'micro', 'macro', 'none'} method for averaging across meta tags (if given) threshold_value : float | list, optional the decision threshold value (cannot be defined for exclusive classification, and 0.5 by default for non-exclusive); if threshold_value is a list, every value should correspond to the class under the same index

class SegmentalPrecision(_ClassificationMetric):
class SegmentalPrecision(_ClassificationMetric):
    """Precision calculated over segments instead of frames (not advised for training)."""

    segmental = True

    def __init__(
        self,
        num_classes: int,
        ignore_index: int = -100,
        average: str = "macro",
        ignored_classes: Set = None,
        exclusive: bool = True,
        iou_threshold: float = 0.5,
        tag_average: str = "micro",
        threshold_value: Union[float, List] = None,
    ):
        """Set up the segmental precision metric.

        Parameters
        ----------
        num_classes : int
            number of classes
        ignore_index : int, default -100
            the class index that indicates ignored samples
        average : {'macro', 'micro', 'none'}
            method for averaging across classes
        ignored_classes : set, optional
            a set of class ids to ignore in calculation
        exclusive : bool, default True
            set to False for multi-label classification tasks
        iou_threshold : float, default 0.5
            intervals with IoU larger than this threshold are considered correct
        tag_average : {'micro', 'macro', 'none'}
            method for averaging across meta tags (if given)
        threshold_value : float | list, optional
            the decision threshold value (cannot be defined for exclusive classification, and 0.5 by default
            for non-exclusive); if `threshold_value` is a list, every value should correspond to the class
            under the same index

        """
        # A missing threshold falls back to 0.5; the base class expects a list of thresholds.
        threshold = 0.5 if threshold_value is None else threshold_value
        super().__init__(
            num_classes=num_classes,
            ignore_index=ignore_index,
            average=average,
            ignored_classes=ignored_classes,
            exclusive=exclusive,
            iou_threshold=iou_threshold,
            tag_average=tag_average,
            threshold_values=[threshold],
        )

    def _calculate_metric(self, tp: Dict, fp: Dict, fn: Dict, tn: Dict) -> float:
        """Return precision, tp / (tp + fp), for the first configured threshold."""
        key = self.threshold_values[0]
        if isinstance(key, list):
            # Per-class threshold lists are looked up by their comma-separated string form.
            key = ", ".join(str(value) for value in key)
        true_positives = tp[key]
        # The small epsilon keeps the ratio defined when there are no positive predictions.
        return true_positives / (true_positives + fp[key] + 1e-7)

Segmental precision (not advised for training).

SegmentalPrecision( num_classes: int, ignore_index: int = -100, average: str = 'macro', ignored_classes: Set = None, exclusive: bool = True, iou_threshold: float = 0.5, tag_average: str = 'micro', threshold_value: Union[float, List] = None)
581    def __init__(
582        self,
583        num_classes: int,
584        ignore_index: int = -100,
585        average: str = "macro",
586        ignored_classes: Set = None,
587        exclusive: bool = True,
588        iou_threshold: float = 0.5,
589        tag_average: str = "micro",
590        threshold_value: Union[float, List] = None,
591    ):
592        """Initialize the class.
593
594        Parameters
595        ----------
596        ignore_index : int, default -100
597            the class index that indicates ignored samples
598        average: {'macro', 'micro', 'none'}
599            method for averaging across classes
600        num_classes : int, optional
601            number of classes (not necessary if main_class is not None)
602        ignored_classes : set, optional
603            a set of class ids to ignore in calculation
604        exclusive: bool, default True
605            set to False for multi-label classification tasks
606        iou_threshold : float, default 0.5
607            if segmental is true, intervals with IoU larger than this threshold are considered correct
608        tag_average: {'micro', 'macro', 'none'}
609            method for averaging across meta tags (if given)
610        threshold_value : float | list, optional
611            the decision threshold value (cannot be defined for exclusive classification, and 0.5 be default
612            for non-exclusive); if `threshold_value` is a list, every value should correspond to the class
613            under the same index
614
615        """
616        if threshold_value is None:
617            threshold_value = 0.5
618        super().__init__(
619            num_classes,
620            ignore_index,
621            average,
622            ignored_classes,
623            exclusive,
624            iou_threshold,
625            tag_average,
626            [threshold_value],
627        )

Initialize the class.

Parameters

ignore_index : int, default -100 the class index that indicates ignored samples average: {'macro', 'micro', 'none'} method for averaging across classes num_classes : int, optional number of classes (not necessary if main_class is not None) ignored_classes : set, optional a set of class ids to ignore in calculation exclusive: bool, default True set to False for multi-label classification tasks iou_threshold : float, default 0.5 if segmental is true, intervals with IoU larger than this threshold are considered correct tag_average: {'micro', 'macro', 'none'} method for averaging across meta tags (if given) threshold_value : float | list, optional the decision threshold value (cannot be defined for exclusive classification, and 0.5 by default for non-exclusive); if threshold_value is a list, every value should correspond to the class under the same index

segmental = True

If True, the metric will be calculated over segments; otherwise over frames.

class Recall(_ClassificationMetric):
class Recall(_ClassificationMetric):
    """Recall, i.e. the share of true positives among all actual positives."""

    def __init__(
        self,
        num_classes: int,
        ignore_index: int = -100,
        average: str = "macro",
        ignored_classes: Set = None,
        exclusive: bool = True,
        tag_average: str = "micro",
        threshold_value: Union[float, List] = None,
    ):
        """Set up the recall metric.

        Parameters
        ----------
        num_classes : int
            number of classes
        ignore_index : int, default -100
            the class index that indicates ignored samples
        average : {'macro', 'micro', 'none'}
            method for averaging across classes
        ignored_classes : set, optional
            a set of class ids to ignore in calculation
        exclusive : bool, default True
            set to False for multi-label classification tasks
        tag_average : {'micro', 'macro', 'none'}
            method for averaging across meta tags (if given)
        threshold_value : float | list, optional
            the decision threshold value (cannot be defined for exclusive classification, and 0.5 by default
            for non-exclusive); if `threshold_value` is a list, every value should correspond to the class
            under the same index

        """
        # A missing threshold falls back to 0.5; the base class expects a list of thresholds.
        threshold = 0.5 if threshold_value is None else threshold_value
        super().__init__(
            num_classes=num_classes,
            ignore_index=ignore_index,
            average=average,
            ignored_classes=ignored_classes,
            exclusive=exclusive,
            tag_average=tag_average,
            threshold_values=[threshold],
        )

    def _calculate_metric(self, tp: Dict, fp: Dict, fn: Dict, tn: Dict) -> float:
        """Return recall, tp / (tp + fn), for the first configured threshold."""
        key = self.threshold_values[0]
        if isinstance(key, list):
            # Per-class threshold lists are looked up by their comma-separated string form.
            key = ", ".join(str(value) for value in key)
        true_positives = tp[key]
        # The small epsilon keeps the ratio defined when there are no actual positives.
        return true_positives / (true_positives + fn[key] + 1e-7)

Recall.

Recall( num_classes: int, ignore_index: int = -100, average: str = 'macro', ignored_classes: Set = None, exclusive: bool = True, tag_average: str = 'micro', threshold_value: Union[float, List] = None)
640    def __init__(
641        self,
642        num_classes: int,
643        ignore_index: int = -100,
644        average: str = "macro",
645        ignored_classes: Set = None,
646        exclusive: bool = True,
647        tag_average: str = "micro",
648        threshold_value: Union[float, List] = None,
649    ):
650        """Initialize the class.
651
652        Parameters
653        ----------
654        ignore_index : int, default -100
655            the class index that indicates ignored samples
656        average: {'macro', 'micro', 'none'}
657            method for averaging across classes
658        num_classes : int, optional
659            number of classes (not necessary if main_class is not None)
660        ignored_classes : set, optional
661            a set of class ids to ignore in calculation
662        exclusive: bool, default True
663            set to False for multi-label classification tasks
664        tag_average: {'micro', 'macro', 'none'}
665            method for averaging across meta tags (if given)
666        threshold_value : float | list, optional
667            the decision threshold value (cannot be defined for exclusive classification, and 0.5 be default
668            for non-exclusive); if `threshold_value` is a list, every value should correspond to the class
669            under the same index
670
671        """
672        if threshold_value is None:
673            threshold_value = 0.5
674        super().__init__(
675            num_classes,
676            ignore_index,
677            average,
678            ignored_classes,
679            exclusive,
680            tag_average=tag_average,
681            threshold_values=[threshold_value],
682        )

Initialize the class.

Parameters

ignore_index : int, default -100 the class index that indicates ignored samples average: {'macro', 'micro', 'none'} method for averaging across classes num_classes : int, optional number of classes (not necessary if main_class is not None) ignored_classes : set, optional a set of class ids to ignore in calculation exclusive: bool, default True set to False for multi-label classification tasks tag_average: {'micro', 'macro', 'none'} method for averaging across meta tags (if given) threshold_value : float | list, optional the decision threshold value (cannot be defined for exclusive classification, and 0.5 by default for non-exclusive); if threshold_value is a list, every value should correspond to the class under the same index

class SegmentalRecall(_ClassificationMetric):
class SegmentalRecall(_ClassificationMetric):
    """Recall calculated over segments instead of frames (not advised for training)."""

    segmental = True

    def __init__(
        self,
        num_classes: int,
        ignore_index: int = -100,
        average: str = "macro",
        ignored_classes: Set = None,
        exclusive: bool = True,
        iou_threshold: float = 0.5,
        tag_average: str = "micro",
        threshold_value: Union[float, List] = None,
    ):
        """Set up the segmental recall metric.

        Parameters
        ----------
        num_classes : int
            number of classes
        ignore_index : int, default -100
            the class index that indicates ignored samples
        average : {'macro', 'micro', 'none'}
            method for averaging across classes
        ignored_classes : set, optional
            a set of class ids to ignore in calculation
        exclusive : bool, default True
            set to False for multi-label classification tasks
        iou_threshold : float, default 0.5
            intervals with IoU larger than this threshold are considered correct
        tag_average : {'micro', 'macro', 'none'}
            method for averaging across meta tags (if given)
        threshold_value : float | list, optional
            the decision threshold value (cannot be defined for exclusive classification, and 0.5 by default
            for non-exclusive); if `threshold_value` is a list, every value should correspond to the class
            under the same index

        """
        # A missing threshold falls back to 0.5; the base class expects a list of thresholds.
        threshold = 0.5 if threshold_value is None else threshold_value
        super().__init__(
            num_classes=num_classes,
            ignore_index=ignore_index,
            average=average,
            ignored_classes=ignored_classes,
            exclusive=exclusive,
            iou_threshold=iou_threshold,
            tag_average=tag_average,
            threshold_values=[threshold],
        )

    def _calculate_metric(self, tp: Dict, fp: Dict, fn: Dict, tn: Dict) -> float:
        """Return recall, tp / (tp + fn), for the first configured threshold."""
        key = self.threshold_values[0]
        if isinstance(key, list):
            # Per-class threshold lists are looked up by their comma-separated string form.
            key = ", ".join(str(value) for value in key)
        true_positives = tp[key]
        # The small epsilon keeps the ratio defined when there are no actual positives.
        return true_positives / (true_positives + fn[key] + 1e-7)

Segmental recall (not advised for training).

SegmentalRecall( num_classes: int, ignore_index: int = -100, average: str = 'macro', ignored_classes: Set = None, exclusive: bool = True, iou_threshold: float = 0.5, tag_average: str = 'micro', threshold_value: Union[float, List] = None)
697    def __init__(
698        self,
699        num_classes: int,
700        ignore_index: int = -100,
701        average: str = "macro",
702        ignored_classes: Set = None,
703        exclusive: bool = True,
704        iou_threshold: float = 0.5,
705        tag_average: str = "micro",
706        threshold_value: Union[float, List] = None,
707    ):
708        """Initialize the class.
709
710        Parameters
711        ----------
712        ignore_index : int, default -100
713            the class index that indicates ignored samples
714        average: {'macro', 'micro', 'none'}
715            method for averaging across classes
716        num_classes : int, optional
717            number of classes (not necessary if main_class is not None)
718        ignored_classes : set, optional
719            a set of class ids to ignore in calculation
720        exclusive: bool, default True
721            set to False for multi-label classification tasks
722        iou_threshold : float, default 0.5
723            if segmental is true, intervals with IoU larger than this threshold are considered correct
724        tag_average: {'micro', 'macro', 'none'}
725            method for averaging across meta tags (if given)
726        threshold_value : float | list, optional
727            the decision threshold value (cannot be defined for exclusive classification, and 0.5 be default
728            for non-exclusive); if `threshold_value` is a list, every value should correspond to the class
729            under the same index
730
731        """
732        if threshold_value is None:
733            threshold_value = 0.5
734        super().__init__(
735            num_classes,
736            ignore_index,
737            average,
738            ignored_classes,
739            exclusive,
740            iou_threshold,
741            tag_average,
742            [threshold_value],
743        )

Initialize the class.

Parameters

ignore_index : int, default -100 the class index that indicates ignored samples average: {'macro', 'micro', 'none'} method for averaging across classes num_classes : int, optional number of classes (not necessary if main_class is not None) ignored_classes : set, optional a set of class ids to ignore in calculation exclusive: bool, default True set to False for multi-label classification tasks iou_threshold : float, default 0.5 if segmental is true, intervals with IoU larger than this threshold are considered correct tag_average: {'micro', 'macro', 'none'} method for averaging across meta tags (if given) threshold_value : float | list, optional the decision threshold value (cannot be defined for exclusive classification, and 0.5 by default for non-exclusive); if threshold_value is a list, every value should correspond to the class under the same index

segmental = True

If True, the metric will be calculated over segments; otherwise over frames.

class F1(_ClassificationMetric):
class F1(_ClassificationMetric):
    """F1 score, the harmonic mean of precision and recall."""

    def __init__(
        self,
        num_classes: int,
        ignore_index: int = -100,
        average: str = "macro",
        ignored_classes: Set = None,
        exclusive: bool = True,
        tag_average: str = "micro",
        threshold_value: Union[float, List] = None,
        integration_interval: int = 0,
    ):
        """Initialize the class.

        Parameters
        ----------
        num_classes : int
            number of classes
        ignore_index : int, default -100
            the class index that indicates ignored samples
        average : {'macro', 'micro', 'none'}
            method for averaging across classes
        ignored_classes : set, optional
            a set of class ids to ignore in calculation
        exclusive : bool, default True
            set to False for multi-label classification tasks
        tag_average : {'micro', 'macro', 'none'}
            method for averaging across meta tags (if given)
        threshold_value : float | list, optional
            the decision threshold value (cannot be defined for exclusive classification, and 0.5 by default
            for non-exclusive); if `threshold_value` is a list, every value should correspond to the class
            under the same index
        integration_interval : int, default 0
            the number of frames to integrate over (0 means no integration)

        """
        if threshold_value is None:
            threshold_value = 0.5
        super().__init__(
            num_classes,
            ignore_index,
            average,
            ignored_classes,
            exclusive,
            tag_average=tag_average,
            threshold_values=[threshold_value],
            integration_interval=integration_interval,
        )

    @staticmethod
    def _f1_at_threshold(tp: Dict, fp: Dict, fn: Dict, threshold) -> float:
        """Compute the F1 score from the counts stored under one threshold."""
        # A per-class threshold list is looked up by its comma-separated string form;
        # a list itself is unhashable and cannot be used as a dictionary key.
        key = ", ".join(map(str, threshold)) if isinstance(threshold, list) else threshold
        recall = tp[key] / (tp[key] + fn[key] + 1e-7)
        precision = tp[key] / (tp[key] + fp[key] + 1e-7)
        # The epsilons keep every ratio defined when a denominator would be zero.
        return 2 * recall * precision / (recall + precision + 1e-7)

    def _calculate_metric(self, tp: Dict, fp: Dict, fn: Dict, tn: Dict) -> float:
        """Calculate the F1 score from true and false positive and negative rates.

        When `self.optimize` is set, the score is computed for every entry of
        `self.threshold_values` and the largest one is returned; otherwise only the
        first threshold is used. `tn` is not needed for this metric.
        """
        if self.optimize:
            scores = [
                self._f1_at_threshold(tp, fp, fn, threshold)
                for threshold in self.threshold_values
            ]
            return max(scores)
        return self._f1_at_threshold(tp, fp, fn, self.threshold_values[0])

F1 score.

F1( num_classes: int, ignore_index: int = -100, average: str = 'macro', ignored_classes: Set = None, exclusive: bool = True, tag_average: str = 'micro', threshold_value: Union[float, List] = None, integration_interval: int = 0)
756    def __init__(
757        self,
758        num_classes: int,
759        ignore_index: int = -100,
760        average: str = "macro",
761        ignored_classes: Set = None,
762        exclusive: bool = True,
763        tag_average: str = "micro",
764        threshold_value: Union[float, List] = None,
765        integration_interval: int = 0,
766    ):
767        """Initialize the class.
768
769        Parameters
770        ----------
771        ignore_index : int, default -100
772            the class index that indicates ignored samples
773        average: {'macro', 'micro', 'none'}
774            method for averaging across classes
775        num_classes : int, optional
776            number of classes (not necessary if main_class is not None)
777        ignored_classes : set, optional
778            a set of class ids to ignore in calculation
779        exclusive: bool, default True
780            set to False for multi-label classification tasks
781        tag_average: {'micro', 'macro', 'none'}
782            method for averaging across meta tags (if given)
783        threshold_value : float | list, optional
784            the decision threshold value (cannot be defined for exclusive classification, and 0.5 be default
785            for non-exclusive); if `threshold_value` is a list, every value should correspond to the class
786            under the same index
787        integration_interval : int, default 0
788            the number of frames to integrate over (0 means no integration)
789
790        """
791        if threshold_value is None:
792            threshold_value = 0.5
793        super().__init__(
794            num_classes,
795            ignore_index,
796            average,
797            ignored_classes,
798            exclusive,
799            tag_average=tag_average,
800            threshold_values=[threshold_value],
801            integration_interval=integration_interval,
802        )

Initialize the class.

Parameters

ignore_index : int, default -100 the class index that indicates ignored samples average: {'macro', 'micro', 'none'} method for averaging across classes num_classes : int, optional number of classes (not necessary if main_class is not None) ignored_classes : set, optional a set of class ids to ignore in calculation exclusive: bool, default True set to False for multi-label classification tasks tag_average: {'micro', 'macro', 'none'} method for averaging across meta tags (if given) threshold_value : float | list, optional the decision threshold value (cannot be defined for exclusive classification, and 0.5 by default for non-exclusive); if threshold_value is a list, every value should correspond to the class under the same index integration_interval : int, default 0 the number of frames to integrate over (0 means no integration)

class SegmentalF1(_ClassificationMetric):
class SegmentalF1(_ClassificationMetric):
    """Segmental F1 score (not advised for training)."""

    segmental = True

    def __init__(
        self,
        num_classes: int,
        ignore_index: int = -100,
        average: str = "macro",
        ignored_classes: Set = None,
        exclusive: bool = True,
        iou_threshold: float = 0.5,
        tag_average: str = "micro",
        threshold_value: Union[float, List] = None,
    ):
        """Initialize the class.

        Parameters
        ----------
        num_classes : int
            number of classes
        ignore_index : int, default -100
            the class index that indicates ignored samples
        average : {'macro', 'micro', 'none'}
            method for averaging across classes
        ignored_classes : set, optional
            a set of class ids to ignore in calculation
        exclusive : bool, default True
            set to False for multi-label classification tasks
        iou_threshold : float, default 0.5
            intervals with IoU larger than this threshold are considered correct
        tag_average : {'micro', 'macro', 'none'}
            method for averaging across meta tags (if given)
        threshold_value : float | list, optional
            the decision threshold value (cannot be defined for exclusive classification, and 0.5 by default
            for non-exclusive); if `threshold_value` is a list, every value should correspond to the class
            under the same index

        """
        if threshold_value is None:
            threshold_value = 0.5
        super().__init__(
            num_classes,
            ignore_index,
            average,
            ignored_classes,
            exclusive,
            iou_threshold,
            tag_average,
            [threshold_value],
        )

    @staticmethod
    def _f1_at_threshold(tp: Dict, fp: Dict, fn: Dict, threshold) -> float:
        """Compute the F1 score from the counts stored under one threshold."""
        # A per-class threshold list is looked up by its comma-separated string form;
        # a list itself is unhashable and cannot be used as a dictionary key.
        key = ", ".join(map(str, threshold)) if isinstance(threshold, list) else threshold
        recall = tp[key] / (tp[key] + fn[key] + 1e-7)
        precision = tp[key] / (tp[key] + fp[key] + 1e-7)
        # The epsilons keep every ratio defined when a denominator would be zero.
        return 2 * recall * precision / (recall + precision + 1e-7)

    def _calculate_metric(self, tp: Dict, fp: Dict, fn: Dict, tn: Dict) -> float:
        """Calculate the F1 score from true and false positive and negative rates.

        When `self.optimize` is set, the score is computed for every entry of
        `self.threshold_values` and the largest one is returned; otherwise only the
        first threshold is used. `tn` is not needed for this metric.
        """
        if self.optimize:
            scores = [
                self._f1_at_threshold(tp, fp, fn, threshold)
                for threshold in self.threshold_values
            ]
            return max(scores)
        return self._f1_at_threshold(tp, fp, fn, self.threshold_values[0])

Segmental F1 score (not advised for training).

SegmentalF1( num_classes: int, ignore_index: int = -100, average: str = 'macro', ignored_classes: Set = None, exclusive: bool = True, iou_threshold: float = 0.5, tag_average: str = 'micro', threshold_value: Union[float, List] = None)
828    def __init__(
829        self,
830        num_classes: int,
831        ignore_index: int = -100,
832        average: str = "macro",
833        ignored_classes: Set = None,
834        exclusive: bool = True,
835        iou_threshold: float = 0.5,
836        tag_average: str = "micro",
837        threshold_value: Union[float, List] = None,
838    ):
839        """Initialize the class.
840
841        Parameters
842        ----------
843        ignore_index : int, default -100
844            the class index that indicates ignored samples
845        average: {'macro', 'micro', 'none'}
846            method for averaging across classes
847        num_classes : int, optional
848            number of classes (not necessary if main_class is not None)
849        ignored_classes : set, optional
850            a set of class ids to ignore in calculation
851        exclusive: bool, default True
852            set to False for multi-label classification tasks
853        iou_threshold : float, default 0.5
854            if segmental is true, intervals with IoU larger than this threshold are considered correct
855        tag_average: {'micro', 'macro', 'none'}
856            method for averaging across meta tags (if given)
857        threshold_value : float | list, optional
858            the decision threshold value (cannot be defined for exclusive classification, and 0.5 be default
859            for non-exclusive); if `threshold_value` is a list, every value should correspond to the class
860            under the same index
861
862        """
863        if threshold_value is None:
864            threshold_value = 0.5
865        super().__init__(
866            num_classes,
867            ignore_index,
868            average,
869            ignored_classes,
870            exclusive,
871            iou_threshold,
872            tag_average,
873            [threshold_value],
874        )

Initialize the class.

Parameters

ignore_index : int, default -100 the class index that indicates ignored samples average: {'macro', 'micro', 'none'} method for averaging across classes num_classes : int, optional number of classes (not necessary if main_class is not None) ignored_classes : set, optional a set of class ids to ignore in calculation exclusive: bool, default True set to False for multi-label classification tasks iou_threshold : float, default 0.5 if segmental is true, intervals with IoU larger than this threshold are considered correct tag_average: {'micro', 'macro', 'none'} method for averaging across meta tags (if given) threshold_value : float | list, optional the decision threshold value (cannot be defined for exclusive classification, and 0.5 by default for non-exclusive); if threshold_value is a list, every value should correspond to the class under the same index

segmental = True

If True, the metric will be calculated over segments; otherwise over frames.

class Fbeta(_ClassificationMetric):
class Fbeta(_ClassificationMetric):
    """F-beta score, a weighted harmonic mean of precision and recall."""

    def __init__(
        self,
        beta: float = 1,
        ignore_index: int = -100,
        average: str = "macro",
        num_classes: int = None,
        ignored_classes: Set = None,
        tag_average: str = "micro",
        exclusive: bool = True,
        threshold_value: float = 0.5,
    ):
        """Initialize the class.

        Parameters
        ----------
        beta : float, default 1
            the beta parameter
        ignore_index : int, default -100
            the class index that indicates ignored samples
        average : {'macro', 'micro', 'none'}
            method for averaging across classes
        num_classes : int, optional
            number of classes
        ignored_classes : set, optional
            a set of class ids to ignore in calculation
        tag_average : {'micro', 'macro', 'none'}
            method for averaging across meta tags (if given)
        exclusive : bool, default True
            set to False for multi-label classification tasks
        threshold_value : float | list, default 0.5
            the decision threshold value (cannot be defined for exclusive classification); `None` is
            replaced by 0.5; if `threshold_value` is a list, every value should correspond to the class
            under the same index

        """
        if threshold_value is None:
            threshold_value = 0.5
        # Only beta squared appears in the F-beta formula, so it is stored once here.
        self.beta2 = beta**2
        super().__init__(
            num_classes,
            ignore_index,
            average,
            ignored_classes,
            exclusive,
            tag_average=tag_average,
            threshold_values=[threshold_value],
        )

    def _fbeta_at_threshold(self, tp: Dict, fp: Dict, fn: Dict, threshold) -> float:
        """Compute the F-beta score from the counts stored under one threshold."""
        # A per-class threshold list is looked up by its comma-separated string form;
        # a list itself is unhashable and cannot be used as a dictionary key.
        key = ", ".join(map(str, threshold)) if isinstance(threshold, list) else threshold
        recall = tp[key] / (tp[key] + fn[key] + 1e-7)
        precision = tp[key] / (tp[key] + fp[key] + 1e-7)
        # The epsilons keep every ratio defined when a denominator would be zero.
        return (
            (1 + self.beta2)
            * precision
            * recall
            / (self.beta2 * precision + recall + 1e-7)
        )

    def _calculate_metric(self, tp: Dict, fp: Dict, fn: Dict, tn: Dict) -> float:
        """Calculate the F-beta score from true and false positive and negative rates.

        When `self.optimize` is set, the score is computed for every entry of
        `self.threshold_values` and the largest one is returned; otherwise only the
        first threshold is used. `tn` is not needed for this metric.
        """
        if self.optimize:
            scores = [
                self._fbeta_at_threshold(tp, fp, fn, threshold)
                for threshold in self.threshold_values
            ]
            return max(scores)
        return self._fbeta_at_threshold(tp, fp, fn, self.threshold_values[0])

F-beta score.

Fbeta( beta: float = 1, ignore_index: int = -100, average: str = 'macro', num_classes: int = None, ignored_classes: Set = None, tag_average: str = 'micro', exclusive: bool = True, threshold_value: float = 0.5)
898    def __init__(
899        self,
900        beta: float = 1,
901        ignore_index: int = -100,
902        average: str = "macro",
903        num_classes: int = None,
904        ignored_classes: Set = None,
905        tag_average: str = "micro",
906        exclusive: bool = True,
907        threshold_value: float = 0.5,
908    ):
909        """Initialize the class.
910
911        Parameters
912        ----------
913        beta : float, default 1
914            the beta parameter
915        ignore_index : int, default -100
916            the class index that indicates ignored samples
917        average: {'macro', 'micro', 'none'}
918            method for averaging across classes
919        num_classes : int, optional
920            number of classes (not necessary if main_class is not None)
921        ignored_classes : set, optional
922            a set of class ids to ignore in calculation
923        exclusive: bool, default True
924            set to False for multi-label classification tasks
925        tag_average: {'micro', 'macro', 'none'}
926            method for averaging across meta tags (if given)
927        threshold_value : float | list, optional
928            the decision threshold value (cannot be defined for exclusive classification, and 0.5 be default
929            for non-exclusive); if `threshold_value` is a list, every value should correspond to the class
930            under the same index
931
932        """
933        if threshold_value is None:
934            threshold_value = 0.5
935        self.beta2 = beta**2
936        super().__init__(
937            num_classes,
938            ignore_index,
939            average,
940            ignored_classes,
941            exclusive,
942            tag_average=tag_average,
943            threshold_values=[threshold_value],
944        )

Initialize the class.

Parameters

beta : float, default 1 the beta parameter ignore_index : int, default -100 the class index that indicates ignored samples average: {'macro', 'micro', 'none'} method for averaging across classes num_classes : int, optional number of classes (not necessary if main_class is not None) ignored_classes : set, optional a set of class ids to ignore in calculation exclusive: bool, default True set to False for multi-label classification tasks tag_average: {'micro', 'macro', 'none'} method for averaging across meta tags (if given) threshold_value : float | list, optional the decision threshold value (cannot be defined for exclusive classification, and 0.5 by default for non-exclusive); if threshold_value is a list, every value should correspond to the class under the same index

beta2
class SegmentalFbeta(_ClassificationMetric):
 977class SegmentalFbeta(_ClassificationMetric):
 978    """Segmental F-beta score (not advised for training)."""
 979
 980    segmental = True
 981
 982    def __init__(
 983        self,
 984        beta: float = 1,
 985        ignore_index: int = -100,
 986        average: str = "macro",
 987        num_classes: int = None,
 988        ignored_classes: Set = None,
 989        iou_threshold: float = 0.5,
 990        tag_average: str = "micro",
 991        exclusive: bool = True,
 992        threshold_value: float = 0.5,
 993    ):
 994        """Initialize the class.
 995
 996        Parameters
 997        ----------
 998        beta : float, default 1
 999            the beta parameter
1000        ignore_index : int, default -100
1001            the class index that indicates ignored samples
1002        average: {'macro', 'micro', 'none'}
1003            method for averaging across classes
1004        num_classes : int, optional
1005            number of classes (not necessary if main_class is not None)
1006        ignored_classes : set, optional
1007            a set of class ids to ignore in calculation
1008        exclusive: bool, default True
1009            set to False for multi-label classification tasks
1010        iou_threshold : float, default 0.5
1011            if segmental is true, intervals with IoU larger than this threshold are considered correct
1012        tag_average: {'micro', 'macro', 'none'}
1013            method for averaging across meta tags (if given)
1014        threshold_value : float | list, optional
1015            the decision threshold value (cannot be defined for exclusive classification, and 0.5 be default
1016            for non-exclusive); if `threshold_value` is a list, every value should correspond to the class
1017            under the same index
1018
1019        """
1020        if threshold_value is None:
1021            threshold_value = 0.5
1022        self.beta2 = beta**2
1023        super().__init__(
1024            num_classes,
1025            ignore_index,
1026            average,
1027            ignored_classes,
1028            exclusive,
1029            iou_threshold=iou_threshold,
1030            tag_average=tag_average,
1031            threshold_values=[threshold_value],
1032        )
1033
1034    def _calculate_metric(self, tp: Dict, fp: Dict, fn: Dict, tn: Dict) -> float:
1035        """Calculate the metric value from true and false positive and negative rates."""
1036        if self.optimize:
1037            scores = []
1038            for k in self.threshold_values:
1039                recall = tp[k] / (tp[k] + fn[k] + 1e-7)
1040                precision = tp[k] / (tp[k] + fp[k] + 1e-7)
1041                scores.append(
1042                    (
1043                        (1 + self.beta2)
1044                        * precision
1045                        * recall
1046                        / (self.beta2 * precision + recall + 1e-7)
1047                    )
1048                )
1049            f1 = max(scores)
1050        else:
1051            k = self.threshold_values[0]
1052            if isinstance(k, list):
1053                k = ", ".join(map(str, k))
1054            recall = tp[k] / (tp[k] + fn[k] + 1e-7)
1055            precision = tp[k] / (tp[k] + fp[k] + 1e-7)
1056            f1 = (
1057                (1 + self.beta2)
1058                * precision
1059                * recall
1060                / (self.beta2 * precision + recall + 1e-7)
1061            )
1062        return f1

Segmental F-beta score (not advised for training).

SegmentalFbeta( beta: float = 1, ignore_index: int = -100, average: str = 'macro', num_classes: int = None, ignored_classes: Set = None, iou_threshold: float = 0.5, tag_average: str = 'micro', exclusive: bool = True, threshold_value: float = 0.5)
 982    def __init__(
 983        self,
 984        beta: float = 1,
 985        ignore_index: int = -100,
 986        average: str = "macro",
 987        num_classes: int = None,
 988        ignored_classes: Set = None,
 989        iou_threshold: float = 0.5,
 990        tag_average: str = "micro",
 991        exclusive: bool = True,
 992        threshold_value: float = 0.5,
 993    ):
 994        """Initialize the class.
 995
 996        Parameters
 997        ----------
 998        beta : float, default 1
 999            the beta parameter
1000        ignore_index : int, default -100
1001            the class index that indicates ignored samples
1002        average: {'macro', 'micro', 'none'}
1003            method for averaging across classes
1004        num_classes : int, optional
1005            number of classes (not necessary if main_class is not None)
1006        ignored_classes : set, optional
1007            a set of class ids to ignore in calculation
1008        exclusive: bool, default True
1009            set to False for multi-label classification tasks
1010        iou_threshold : float, default 0.5
1011            if segmental is true, intervals with IoU larger than this threshold are considered correct
1012        tag_average: {'micro', 'macro', 'none'}
1013            method for averaging across meta tags (if given)
1014        threshold_value : float | list, optional
1015            the decision threshold value (cannot be defined for exclusive classification, and 0.5 be default
1016            for non-exclusive); if `threshold_value` is a list, every value should correspond to the class
1017            under the same index
1018
1019        """
1020        if threshold_value is None:
1021            threshold_value = 0.5
1022        self.beta2 = beta**2
1023        super().__init__(
1024            num_classes,
1025            ignore_index,
1026            average,
1027            ignored_classes,
1028            exclusive,
1029            iou_threshold=iou_threshold,
1030            tag_average=tag_average,
1031            threshold_values=[threshold_value],
1032        )

Initialize the class.

Parameters

beta : float, default 1 the beta parameter ignore_index : int, default -100 the class index that indicates ignored samples average: {'macro', 'micro', 'none'} method for averaging across classes num_classes : int, optional number of classes (not necessary if main_class is not None) ignored_classes : set, optional a set of class ids to ignore in calculation exclusive: bool, default True set to False for multi-label classification tasks iou_threshold : float, default 0.5 if segmental is true, intervals with IoU larger than this threshold are considered correct tag_average: {'micro', 'macro', 'none'} method for averaging across meta tags (if given) threshold_value : float | list, optional the decision threshold value (cannot be defined for exclusive classification, and 0.5 by default for non-exclusive); if threshold_value is a list, every value should correspond to the class under the same index

segmental = True

If True, the metric will be calculated over segments; otherwise over frames.

beta2
class SemiSegmentalRecall(_SemiSegmentalMetric):
class SemiSegmentalRecall(_SemiSegmentalMetric):
    """Semi-segmental recall (not advised for training).

    A compromise between segmental and frame-wise recall.

    The value is obtained as follows:
    1) too-short intervals are smoothed over in both ground truth and prediction (groups of
        zeros shorter than `smooth_interval` are removed first, then groups of ones),
    2) every ground truth interval is extended by `delta` frames at both ends and the predicted
        positive frames inside the extended interval are counted (intersection),
    3) a per-interval threshold is computed as
        `t = sigmoid(4 * (a * x + b)) * (iou_threshold_long - iou_threshold_short)`, with
        `a = 2 / (long_length - short_length)`, `b = 1 - a * long_length` and `x` the length of
        the interval before `delta` was added,
    4) an interval whose intersection exceeds `t * x` counts as a true positive (`TP`), any
        other interval as a false negative (`FN`),
    5) the metric is `TP / (TP + FN)`.
    """

    def __init__(
        self,
        num_classes: int,
        ignore_index: int = -100,
        ignored_classes: Set = None,
        exclusive: bool = True,
        average: str = "macro",
        tag_average: str = "micro",
        delta: int = 0,
        smooth_interval: int = 0,
        iou_threshold_long: float = 0.5,
        iou_threshold_short: float = 0.5,
        short_length: int = 30,
        long_length: int = 300,
        threshold_value: Union[float, List] = None,
    ) -> None:
        """Set up the semi-segmental recall metric.

        Parameters
        ----------
        num_classes : int
            the number of classes in the dataset
        ignore_index : int, default -100
            the ground truth label to ignore
        ignored_classes : set, optional
            the class indices to leave out of the computation
        exclusive : bool, default True
            `False` for multi-label classification tasks
        average : {"macro", "micro", "none"}
            how the results are averaged over classes
        tag_average : {"macro", "micro", "none"}
            how the results are averaged over meta tags (if given)
        delta : int, default 0
            the number of frames added to each ground truth interval before computing the
            intersection (see the class description)
        smooth_interval : int, default 0
            intervals shorter than this number of frames are ignored, both in prediction and
            in ground truth (see the class description)
        iou_threshold_long : float, default 0.5
            the intersection threshold (between 0 and 1) for segments longer than
            `long_length` frames (see the class description)
        iou_threshold_short : float, default 0.5
            the intersection threshold (between 0 and 1) for segments shorter than
            `short_length` frames (see the class description)
        short_length : int, default 30
            the length in frames below which intervals get the `iou_threshold_short`
            intersection threshold (see the class description)
        long_length : int, default 300
            the length in frames above which intervals get the `iou_threshold_long`
            intersection threshold (see the class description)
        threshold_value : float | list, optional
            the decision threshold value (cannot be defined for exclusive classification); it
            is forwarded to the base class wrapped in a one-element list

        """
        # all arguments are handed to the base class positionally, in its expected order
        super().__init__(
            num_classes, ignore_index, ignored_classes, exclusive,
            average, tag_average,
            delta, smooth_interval,
            iou_threshold_long, iou_threshold_short,
            short_length, long_length,
            [threshold_value],
        )

    def _calculate_metric(self, tp: Dict, fp: Dict, fn: Dict, tn: Dict) -> float:
        """Compute recall, `TP / (TP + FN)`, for the first configured threshold.

        The count dictionaries are indexed by threshold; a list-valued threshold is looked up
        by its values joined with `", "`. `fp` and `tn` are not used.
        """
        key = self.threshold_values[0]
        if isinstance(key, list):
            key = ", ".join(str(value) for value in key)
        hits = tp[key]
        # 1e-7 keeps the denominator away from zero
        return hits / (hits + fn[key] + 1e-7)

Semi-segmental recall (not advised for training).

A metric in-between segmental and frame-wise recall.

This metric follows the following algorithm: 1) smooth over too-short intervals, both in ground truth and in prediction (first remove groups of zeros shorter than smooth_interval and then do the same with groups of ones), 2) add delta frames to each ground truth interval at both ends and count the number of predicted positive frames at the resulting intervals (intersection), 3) calculate the threshold for each interval as t = sigmoid(4 * (a * x + b)) * (iou_threshold_long - iou_threshold_short), where a = 2 / (long_length - short_length), b = 1 - a * long_length, x is the length of the interval before delta was added, 4) for each interval, if intersection is higher than t * x, the interval is labeled as true positive (TP), and otherwise as false negative (FN), 5) the final metric value is computed as TP / (TP + FN).

SemiSegmentalRecall( num_classes: int, ignore_index: int = -100, ignored_classes: Set = None, exclusive: bool = True, average: str = 'macro', tag_average: str = 'micro', delta: int = 0, smooth_interval: int = 0, iou_threshold_long: float = 0.5, iou_threshold_short: float = 0.5, short_length: int = 30, long_length: int = 300, threshold_value: Union[float, List] = None)
1336    def __init__(
1337        self,
1338        num_classes: int,
1339        ignore_index: int = -100,
1340        ignored_classes: Set = None,
1341        exclusive: bool = True,
1342        average: str = "macro",
1343        tag_average: str = "micro",
1344        delta: int = 0,
1345        smooth_interval: int = 0,
1346        iou_threshold_long: float = 0.5,
1347        iou_threshold_short: float = 0.5,
1348        short_length: int = 30,
1349        long_length: int = 300,
1350        threshold_value: Union[float, List] = None,
1351    ) -> None:
1352        """Initialize the class.
1353
1354        Parameters
1355        ----------
1356        num_classes : int
1357            the number of classes in the dataset
1358        ignore_index : int, default -100
1359            the ground truth label to ignore
1360        ignored_classes : set, optional
1361            the class indices to ignore in computation
1362        exclusive : bool, default True
1363            `False` for multi-label classification tasks
1364        average : {"macro", "micro", "none"}
1365            the method to average the results over classes
1366        tag_average : {"macro", "micro", "none"}
1367            the method to average the results over meta tags (if given)
1368        delta : int, default 0
1369            the number of frames to add to each ground truth interval before computing the intersection,
1370            see description of the class for details
1371        smooth_interval : int, default 0
1372            intervals shorter than this number of frames will be ignored (both in prediction and in ground truth,
1373            see description of the class for details
1374        iou_threshold_long : float, default 0.5
1375            the intersection threshold for segments longer than `long_length` frames (between 0 and 1),
1376            see description of the class for details
1377        iou_threshold_short : float, default 0.5
1378            the intersection threshold for segments shorter than `short_length` frames (between 0 and 1),
1379            see description of the class for details
1380        short_length : int, default 30
1381            the threshold number of frames for short intervals that will have an intersection threshold of
1382            `iou_threshold_short`, see description of the class for details
1383        long_length : int, default 300
1384            the threshold number of frames for long intervals that will have an intersection threshold of
1385            `iou_threshold_long`, see description of the class for details
1386        threshold_value : float | list, optional
1387            the decision threshold value (cannot be defined for exclusive classification, and 0.5 be default)
1388
1389        """
1390        super().__init__(
1391            num_classes,
1392            ignore_index,
1393            ignored_classes,
1394            exclusive,
1395            average,
1396            tag_average,
1397            delta,
1398            smooth_interval,
1399            iou_threshold_long,
1400            iou_threshold_short,
1401            short_length,
1402            long_length,
1403            [threshold_value],
1404        )

Initialize the class.

Parameters

num_classes : int the number of classes in the dataset ignore_index : int, default -100 the ground truth label to ignore ignored_classes : set, optional the class indices to ignore in computation exclusive : bool, default True False for multi-label classification tasks average : {"macro", "micro", "none"} the method to average the results over classes tag_average : {"macro", "micro", "none"} the method to average the results over meta tags (if given) delta : int, default 0 the number of frames to add to each ground truth interval before computing the intersection, see description of the class for details smooth_interval : int, default 0 intervals shorter than this number of frames will be ignored (both in prediction and in ground truth), see description of the class for details iou_threshold_long : float, default 0.5 the intersection threshold for segments longer than long_length frames (between 0 and 1), see description of the class for details iou_threshold_short : float, default 0.5 the intersection threshold for segments shorter than short_length frames (between 0 and 1), see description of the class for details short_length : int, default 30 the threshold number of frames for short intervals that will have an intersection threshold of iou_threshold_short, see description of the class for details long_length : int, default 300 the threshold number of frames for long intervals that will have an intersection threshold of iou_threshold_long, see description of the class for details threshold_value : float | list, optional the decision threshold value (cannot be defined for exclusive classification, and 0.5 by default)

class SemiSegmentalPrecision(_SemiSegmentalMetric):
class SemiSegmentalPrecision(_SemiSegmentalMetric):
    """Semi-segmental precision (not advised for training).

    A metric in-between segmental and frame-wise precision.

    This metric follows the following algorithm:
    1) smooth over too-short intervals, both in ground truth and in prediction (first remove
        groups of zeros shorter than `smooth_interval` and then do the same with groups of ones),
    2) add `delta` frames to each predicted interval at both ends and count the number of ground truth
        positive frames at the resulting intervals (intersection),
    3) calculate the threshold for each interval as
        `t = sigmoid(4 * (a * x + b)) * (iou_threshold_long - iou_threshold_short)`, where
        `a = 2 / (long_length - short_length)`, `b = 1 - a * long_length`, `x` is the length of the interval
        before `delta` was added,
    4) for each interval, if intersection is higher than `t * x`, the interval is labeled as true positive (`TP`),
        and otherwise as false positive (`FP`),
    5) the final metric value is computed as `TP / (TP + FP)`.

    """

    def __init__(
        self,
        num_classes: int,
        ignore_index: int = -100,
        ignored_classes: Set = None,
        exclusive: bool = True,
        average: str = "macro",
        tag_average: str = "micro",
        delta: int = 0,
        smooth_interval: int = 0,
        iou_threshold_long: float = 0.5,
        iou_threshold_short: float = 0.5,
        short_length: int = 30,
        long_length: int = 300,
        threshold_value: Union[float, List] = None,
    ) -> None:
        """Initialize the class.

        Parameters
        ----------
        num_classes : int
            the number of classes in the dataset
        ignore_index : int, default -100
            the ground truth label to ignore
        ignored_classes : set, optional
            the class indices to ignore in computation
        exclusive : bool, default True
            `False` for multi-label classification tasks
        average : {"macro", "micro", "none"}
            the method to average the results over classes
        tag_average : {"macro", "micro", "none"}
            the method to average the results over meta tags (if given)
        delta : int, default 0
            the number of frames to add to each predicted interval before computing the intersection,
            see description of the class for details
        smooth_interval : int, default 0
            intervals shorter than this number of frames will be ignored (both in prediction and in
            ground truth), see description of the class for details
        iou_threshold_long : float, default 0.5
            the intersection threshold for segments longer than `long_length` frames (between 0 and 1),
            see description of the class for details
        iou_threshold_short : float, default 0.5
            the intersection threshold for segments shorter than `short_length` frames (between 0 and 1),
            see description of the class for details
        short_length : int, default 30
            the threshold number of frames for short intervals that will have an intersection threshold of
            `iou_threshold_short`, see description of the class for details
        long_length : int, default 300
            the threshold number of frames for long intervals that will have an intersection threshold of
            `iou_threshold_long`, see description of the class for details
        threshold_value : float | list, optional
            the decision threshold value (cannot be defined for exclusive classification); it is
            forwarded to the base class wrapped in a one-element list

        """
        super().__init__(
            num_classes,
            ignore_index,
            ignored_classes,
            exclusive,
            average,
            tag_average,
            delta,
            smooth_interval,
            iou_threshold_long,
            iou_threshold_short,
            short_length,
            long_length,
            [threshold_value],
        )

    def _calculate_metric(self, tp: Dict, fp: Dict, fn: Dict, tn: Dict) -> float:
        """Calculate precision, `TP / (TP + FP)`, for the first configured threshold.

        The count dictionaries are indexed by threshold; a list-valued threshold is looked up
        by its values joined with `", "`. `fn` and `tn` are not used.
        """
        k = self.threshold_values[0]
        if isinstance(k, list):
            k = ", ".join(map(str, k))
        # Precision is built from true positives; using true negatives here would yield
        # specificity instead. 1e-7 keeps the denominator away from zero.
        return tp[k] / (tp[k] + fp[k] + 1e-7)

Semi-segmental precision (not advised for training).

A metric in-between segmental and frame-wise precision.

This metric follows the following algorithm: 1) smooth over too-short intervals, both in ground truth and in prediction (first remove groups of zeros shorter than smooth_interval and then do the same with groups of ones), 2) add delta frames to each predicted interval at both ends and count the number of ground truth positive frames at the resulting intervals (intersection), 3) calculate the threshold for each interval as t = sigmoid(4 * (a * x + b)) * (iou_threshold_long - iou_threshold_short), where a = 2 / (long_length - short_length), b = 1 - a * long_length, x is the length of the interval before delta was added, 4) for each interval, if intersection is higher than t * x, the interval is labeled as true positive (TP), and otherwise as false positive (FP), 5) the final metric value is computed as TP / (TP + FP).

SemiSegmentalPrecision( num_classes: int, ignore_index: int = -100, ignored_classes: Set = None, exclusive: bool = True, average: str = 'macro', tag_average: str = 'micro', delta: int = 0, smooth_interval: int = 0, iou_threshold_long: float = 0.5, iou_threshold_short: float = 0.5, short_length: int = 30, long_length: int = 300, threshold_value: Union[float, List] = None)
1433    def __init__(
1434        self,
1435        num_classes: int,
1436        ignore_index: int = -100,
1437        ignored_classes: Set = None,
1438        exclusive: bool = True,
1439        average: str = "macro",
1440        tag_average: str = "micro",
1441        delta: int = 0,
1442        smooth_interval: int = 0,
1443        iou_threshold_long: float = 0.5,
1444        iou_threshold_short: float = 0.5,
1445        short_length: int = 30,
1446        long_length: int = 300,
1447        threshold_value: Union[float, List] = None,
1448    ) -> None:
1449        """Initialize the class.
1450
1451        Parameters
1452        ----------
1453        num_classes : int
1454            the number of classes in the dataset
1455        ignore_index : int, default -100
1456            the ground truth label to ignore
1457        ignored_classes : set, optional
1458            the class indices to ignore in computation
1459        exclusive : bool, default True
1460            `False` for multi-label classification tasks
1461        average : {"macro", "micro", "none"}
1462            the method to average the results over classes
1463        tag_average : {"macro", "micro", "none"}
1464            the method to average the results over meta tags (if given)
1465        delta : int, default 0
1466            the number of frames to add to each ground truth interval before computing the intersection,
1467            see description of the class for details
1468        smooth_interval : int, default 0
1469            intervals shorter than this number of frames will be ignored (both in prediction and in ground truth,
1470            see description of the class for details
1471        iou_threshold_long : float, default 0.5
1472            the intersection threshold for segments longer than `long_length` frames (between 0 and 1),
1473            see description of the class for details
1474        iou_threshold_short : float, default 0.5
1475            the intersection threshold for segments shorter than `short_length` frames (between 0 and 1),
1476            see description of the class for details
1477        short_length : int, default 30
1478            the threshold number of frames for short intervals that will have an intersection threshold of
1479            `iou_threshold_short`, see description of the class for details
1480        long_length : int, default 300
1481            the threshold number of frames for long intervals that will have an intersection threshold of
1482            `iou_threshold_long`, see description of the class for details
1483        threshold_value : float | list, optional
1484            the decision threshold value (cannot be defined for exclusive classification, and 0.5 be default)
1485
1486        """
1487        super().__init__(
1488            num_classes,
1489            ignore_index,
1490            ignored_classes,
1491            exclusive,
1492            average,
1493            tag_average,
1494            delta,
1495            smooth_interval,
1496            iou_threshold_long,
1497            iou_threshold_short,
1498            short_length,
1499            long_length,
1500            [threshold_value],
1501        )

Initialize the class.

Parameters

num_classes : int the number of classes in the dataset ignore_index : int, default -100 the ground truth label to ignore ignored_classes : set, optional the class indices to ignore in computation exclusive : bool, default True False for multi-label classification tasks average : {"macro", "micro", "none"} the method to average the results over classes tag_average : {"macro", "micro", "none"} the method to average the results over meta tags (if given) delta : int, default 0 the number of frames to add to each ground truth interval before computing the intersection, see description of the class for details smooth_interval : int, default 0 intervals shorter than this number of frames will be ignored (both in prediction and in ground truth), see description of the class for details iou_threshold_long : float, default 0.5 the intersection threshold for segments longer than long_length frames (between 0 and 1), see description of the class for details iou_threshold_short : float, default 0.5 the intersection threshold for segments shorter than short_length frames (between 0 and 1), see description of the class for details short_length : int, default 30 the threshold number of frames for short intervals that will have an intersection threshold of iou_threshold_short, see description of the class for details long_length : int, default 300 the threshold number of frames for long intervals that will have an intersection threshold of iou_threshold_long, see description of the class for details threshold_value : float | list, optional the decision threshold value (cannot be defined for exclusive classification, and 0.5 by default)

class SemiSegmentalF1(_SemiSegmentalMetric):
class SemiSegmentalF1(_SemiSegmentalMetric):
    """The F1 score for semi-segmental recall and precision (not advised for training)."""

    def __init__(
        self,
        num_classes: int,
        ignore_index: int = -100,
        ignored_classes: Set = None,
        exclusive: bool = True,
        average: str = "macro",
        tag_average: str = "micro",
        delta: int = 0,
        smooth_interval: int = 0,
        iou_threshold_long: float = 0.5,
        iou_threshold_short: float = 0.5,
        short_length: int = 30,
        long_length: int = 300,
        threshold_value: Union[float, List] = None,
    ) -> None:
        """Initialize the class.

        Parameters
        ----------
        num_classes : int
            the number of classes in the dataset
        ignore_index : int, default -100
            the ground truth label to ignore
        ignored_classes : set, optional
            the class indices to ignore in computation
        exclusive : bool, default True
            `False` for multi-label classification tasks
        average : {"macro", "micro", "none"}
            the method to average the results over classes
        tag_average : {"macro", "micro", "none"}
            the method to average the results over meta tags (if given)
        delta : int, default 0
            the number of frames to add to each ground truth interval before computing the intersection,
            see description of the class for details
        smooth_interval : int, default 0
            intervals shorter than this number of frames will be ignored (both in prediction and in
            ground truth), see description of the class for details
        iou_threshold_long : float, default 0.5
            the intersection threshold for segments longer than `long_length` frames (between 0 and 1),
            see description of the class for details
        iou_threshold_short : float, default 0.5
            the intersection threshold for segments shorter than `short_length` frames (between 0 and 1),
            see description of the class for details
        short_length : int, default 30
            the threshold number of frames for short intervals that will have an intersection threshold of
            `iou_threshold_short`, see description of the class for details
        long_length : int, default 300
            the threshold number of frames for long intervals that will have an intersection threshold of
            `iou_threshold_long`, see description of the class for details
        threshold_value : float | list, optional
            the decision threshold value (cannot be defined for exclusive classification, and 0.5 by default)

        """
        # The threshold is forwarded wrapped in a one-element list (last positional argument).
        super().__init__(
            num_classes,
            ignore_index,
            ignored_classes,
            exclusive,
            average,
            tag_average,
            delta,
            smooth_interval,
            iou_threshold_long,
            iou_threshold_short,
            short_length,
            long_length,
            [threshold_value],
        )

    def _calculate_metric(self, tp: Dict, fp: Dict, fn: Dict, tn: Dict) -> float:
        """Calculate the F1 value from the per-threshold count dictionaries.

        When `self.optimize` is set, the best F1 over all `self.threshold_values` is returned;
        otherwise only the first threshold is used.
        """

        def f1_at(key):
            # NOTE(review): precision is derived from `tn`, not `tp` - presumably the
            # semi-segmental base class stores prediction-side matches there; confirm
            # against `_SemiSegmentalMetric` before changing.
            recall = tp[key] / (tp[key] + fn[key] + 1e-7)
            precision = tn[key] / (tn[key] + fp[key] + 1e-7)
            return 2 * recall * precision / (recall + precision + 1e-7)

        if self.optimize:
            return max([f1_at(threshold) for threshold in self.threshold_values])

        key = self.threshold_values[0]
        if isinstance(key, list):
            # A list-valued threshold is looked up under its comma-joined string form.
            key = ", ".join(map(str, key))
        return f1_at(key)

The F1 score for semi-segmental recall and precision (not advised for training).

SemiSegmentalF1( num_classes: int, ignore_index: int = -100, ignored_classes: Set = None, exclusive: bool = True, average: str = 'macro', tag_average: str = 'micro', delta: int = 0, smooth_interval: int = 0, iou_threshold_long: float = 0.5, iou_threshold_short: float = 0.5, short_length: int = 30, long_length: int = 300, threshold_value: Union[float, List] = None)
1514    def __init__(
1515        self,
1516        num_classes: int,
1517        ignore_index: int = -100,
1518        ignored_classes: Set = None,
1519        exclusive: bool = True,
1520        average: str = "macro",
1521        tag_average: str = "micro",
1522        delta: int = 0,
1523        smooth_interval: int = 0,
1524        iou_threshold_long: float = 0.5,
1525        iou_threshold_short: float = 0.5,
1526        short_length: int = 30,
1527        long_length: int = 300,
1528        threshold_value: Union[float, List] = None,
1529    ) -> None:
1530        """Initialize the class.
1531
1532        Parameters
1533        ----------
1534        num_classes : int
1535            the number of classes in the dataset
1536        ignore_index : int, default -100
1537            the ground truth label to ignore
1538        ignored_classes : set, optional
1539            the class indices to ignore in computation
1540        exclusive : bool, default True
1541            `False` for multi-label classification tasks
1542        average : {"macro", "micro", "none"}
1543            the method to average the results over classes
1544        tag_average : {"macro", "micro", "none"}
1545            the method to average the results over meta tags (if given)
1546        delta : int, default 0
1547            the number of frames to add to each ground truth interval before computing the intersection,
1548            see description of the class for details
1549        smooth_interval : int, default 0
1550            intervals shorter than this number of frames will be ignored (both in prediction and in ground truth,
1551            see description of the class for details
1552        iou_threshold_long : float, default 0.5
1553            the intersection threshold for segments longer than `long_length` frames (between 0 and 1),
1554            see description of the class for details
1555        iou_threshold_short : float, default 0.5
1556            the intersection threshold for segments shorter than `short_length` frames (between 0 and 1),
1557            see description of the class for details
1558        short_length : int, default 30
1559            the threshold number of frames for short intervals that will have an intersection threshold of
1560            `iou_threshold_short`, see description of the class for details
1561        long_length : int, default 300
1562            the threshold number of frames for long intervals that will have an intersection threshold of
1563            `iou_threshold_long`, see description of the class for details
1564        threshold_value : float | list, optional
1565            the decision threshold value (cannot be defined for exclusive classification, and 0.5 be default)
1566
1567        """
1568        super().__init__(
1569            num_classes,
1570            ignore_index,
1571            ignored_classes,
1572            exclusive,
1573            average,
1574            tag_average,
1575            delta,
1576            smooth_interval,
1577            iou_threshold_long,
1578            iou_threshold_short,
1579            short_length,
1580            long_length,
1581            [threshold_value],
1582        )

Initialize the class.

Parameters

num_classes : int the number of classes in the dataset ignore_index : int, default -100 the ground truth label to ignore ignored_classes : set, optional the class indices to ignore in computation exclusive : bool, default True False for multi-label classification tasks average : {"macro", "micro", "none"} the method to average the results over classes tag_average : {"macro", "micro", "none"} the method to average the results over meta tags (if given) delta : int, default 0 the number of frames to add to each ground truth interval before computing the intersection, see description of the class for details smooth_interval : int, default 0 intervals shorter than this number of frames will be ignored (both in prediction and in ground truth), see description of the class for details iou_threshold_long : float, default 0.5 the intersection threshold for segments longer than long_length frames (between 0 and 1), see description of the class for details iou_threshold_short : float, default 0.5 the intersection threshold for segments shorter than short_length frames (between 0 and 1), see description of the class for details short_length : int, default 30 the threshold number of frames for short intervals that will have an intersection threshold of iou_threshold_short, see description of the class for details long_length : int, default 300 the threshold number of frames for long intervals that will have an intersection threshold of iou_threshold_long, see description of the class for details threshold_value : float | list, optional the decision threshold value (cannot be defined for exclusive classification, and 0.5 by default)

class SemiSegmentalPR_AUC(_SemiSegmentalMetric):
class SemiSegmentalPR_AUC(_SemiSegmentalMetric):
    """The area under the precision-recall curve for semi-segmental metrics (not advised for training)."""

    def __init__(
        self,
        num_classes: int,
        ignore_index: int = -100,
        ignored_classes: Set = None,
        exclusive: bool = True,
        average: str = "macro",
        tag_average: str = "micro",
        delta: int = 0,
        smooth_interval: int = 0,
        iou_threshold_long: float = 0.5,
        iou_threshold_short: float = 0.5,
        short_length: int = 30,
        long_length: int = 300,
        threshold_step: float = 0.1,
    ) -> None:
        """Initialize the class.

        Parameters
        ----------
        num_classes : int
            the number of classes in the dataset
        ignore_index : int, default -100
            the ground truth label to ignore
        ignored_classes : set, optional
            the class indices to ignore in computation
        exclusive : bool, default True
            `False` for multi-label classification tasks
        average : {"macro", "micro", "none"}
            the method to average the results over classes
        tag_average : {"macro", "micro", "none"}
            the method to average the results over meta tags (if given)
        delta : int, default 0
            the number of frames to add to each ground truth interval before computing the intersection
        smooth_interval : int, default 0
            intervals shorter than this number of frames will be ignored (both in prediction and in
            ground truth)
        iou_threshold_long : float, default 0.5
            the intersection threshold for segments longer than `long_length` frames (between 0 and 1)
        iou_threshold_short : float, default 0.5
            the intersection threshold for segments shorter than `short_length` frames (between 0 and 1)
        short_length : int, default 30
            the threshold number of frames for short intervals
        long_length : int, default 300
            the threshold number of frames for long intervals
        threshold_step : float, default 0.1
            the spacing of the decision thresholds, which are generated as
            `np.arange(0, 1, threshold_step)`

        """
        # The curve is sampled at evenly spaced thresholds in [0, 1).
        thresholds = list(np.arange(0, 1, threshold_step))
        super().__init__(
            num_classes,
            ignore_index,
            ignored_classes,
            exclusive,
            average,
            tag_average,
            delta,
            smooth_interval,
            iou_threshold_long,
            iou_threshold_short,
            short_length,
            long_length,
            thresholds,
        )

    def _calculate_metric(self, tp: Dict, fp: Dict, fn: Dict, tn: Dict) -> float:
        """Calculate the area under the precision-recall curve from the per-threshold counts."""
        ordered = sorted(self.threshold_values)
        # NOTE(review): precision is derived from `tn`, not `tp` - presumably the base class
        # stores prediction-side matches there; confirm against `_SemiSegmentalMetric`.
        precisions = [tn[k] / (tn[k] + fp[k] + 1e-7) for k in ordered]
        recalls = [tp[k] / (tp[k] + fn[k] + 1e-7) for k in ordered]
        return metrics.auc(x=recalls, y=precisions)

The area under the precision-recall curve for semi-segmental metrics (not advised for training).

SemiSegmentalPR_AUC( num_classes: int, ignore_index: int = -100, ignored_classes: Set = None, exclusive: bool = True, average: str = 'macro', tag_average: str = 'micro', delta: int = 0, smooth_interval: int = 0, iou_threshold_long: float = 0.5, iou_threshold_short: float = 0.5, short_length: int = 30, long_length: int = 300, threshold_step: float = 0.1)
1606    def __init__(
1607        self,
1608        num_classes: int,
1609        ignore_index: int = -100,
1610        ignored_classes: Set = None,
1611        exclusive: bool = True,
1612        average: str = "macro",
1613        tag_average: str = "micro",
1614        delta: int = 0,
1615        smooth_interval: int = 0,
1616        iou_threshold_long: float = 0.5,
1617        iou_threshold_short: float = 0.5,
1618        short_length: int = 30,
1619        long_length: int = 300,
1620        threshold_step: float = 0.1,
1621    ) -> None:
1622        super().__init__(
1623            num_classes,
1624            ignore_index,
1625            ignored_classes,
1626            exclusive,
1627            average,
1628            tag_average,
1629            delta,
1630            smooth_interval,
1631            iou_threshold_long,
1632            iou_threshold_short,
1633            short_length,
1634            long_length,
1635            list(np.arange(0, 1, threshold_step)),
1636        )

Initialize the class.

Parameters

num_classes : int the number of classes in the dataset ignore_index : int, default -100 the ground truth label to ignore ignored_classes : set, optional the class indices to ignore in computation exclusive : bool, default True False for multi-label classification tasks average : {"macro", "micro", "none"} the method to average the results over classes tag_average : {"macro", "micro", "none"} the method to average the results over meta tags (if given) delta : int, default 0 the number of frames to add to each ground truth interval before computing the intersection, see description of the class for details smooth_interval : int, default 0 intervals shorter than this number of frames will be ignored (both in prediction and in ground truth), see description of the class for details iou_threshold_long : float, default 0.5 the intersection threshold for segments longer than long_length frames (between 0 and 1), see description of the class for details iou_threshold_short : float, default 0.5 the intersection threshold for segments shorter than short_length frames (between 0 and 1), see description of the class for details short_length : int, default 30 the threshold number of frames for short intervals that will have an intersection threshold of iou_threshold_short, see description of the class for details long_length : int, default 300 the threshold number of frames for long intervals that will have an intersection threshold of iou_threshold_long, see description of the class for details threshold_values : list, optional the decision threshold values (cannot be defined for exclusive classification, and 0.5 by default)

class Accuracy(dlc2action.metric.base_metric.Metric):
class Accuracy(Metric):
    """Accuracy."""

    def __init__(self, ignore_index=-100):
        """Initialize the class.

        Parameters
        ----------
        ignore_index : int, default -100
            the class index that marks samples to be left out of the computation

        """
        super().__init__()
        self.ignore_index = ignore_index

    def reset(self) -> None:
        """Reset the internal state (at the beginning of an epoch)."""
        self.total, self.correct = 0, 0

    def calculate(self) -> float:
        """Calculate the metric value.

        Returns
        -------
        metric : float
            the number of correct samples divided by the number of counted samples

        """
        # The small epsilon keeps the division defined when nothing has been counted yet.
        denominator = self.total + 1e-7
        return self.correct / denominator

    def update(
        self,
        predicted: torch.Tensor,
        target: torch.Tensor,
        tags: torch.Tensor = None,
    ) -> None:
        """Update the internal state (with a batch).

        Parameters
        ----------
        predicted : torch.Tensor
            the main prediction tensor generated by the model
        target : torch.Tensor
            the corresponding main target tensor
        tags : torch.Tensor, optional
            the tensor of meta tags (not used by this metric)

        """
        # Only positions whose target differs from `ignore_index` are counted.
        valid = target != self.ignore_index
        self.total += valid.sum()
        self.correct += (target == predicted)[valid].sum()

Accuracy.

Accuracy(ignore_index=-100)
1651    def __init__(self, ignore_index=-100):
1652        """Initialize the class.
1653
1654        Parameters
1655        ----------
1656        ignore_index: int
1657            the class index that indicates ignored sample
1658
1659        """
1660        super().__init__()
1661        self.ignore_index = ignore_index

Initialize the class.

Parameters

ignore_index: int the class index that indicates ignored sample

ignore_index
def reset(self) -> None:
1663    def reset(self) -> None:
1664        """Reset the internal state (at the beginning of an epoch)."""
1665        self.total = 0
1666        self.correct = 0

Reset the internal state (at the beginning of an epoch).

def calculate(self) -> float:
1668    def calculate(self) -> float:
1669        """Calculate the metric value.
1670
1671        Returns
1672        -------
1673        metric : float
1674            metric value
1675
1676        """
1677        return self.correct / (self.total + 1e-7)

Calculate the metric value.

Returns

metric : float metric value

def update( self, predicted: torch.Tensor, target: torch.Tensor, tags: torch.Tensor = None) -> None:
1679    def update(
1680        self,
1681        predicted: torch.Tensor,
1682        target: torch.Tensor,
1683        tags: torch.Tensor = None,
1684    ) -> None:
1685        """Update the internal state (with a batch).
1686
1687        Parameters
1688        ----------
1689        predicted : torch.Tensor
1690            the main prediction tensor generated by the model
1691        ssl_predicted : torch.Tensor
1692            the SSL prediction tensor generated by the model
1693        target : torch.Tensor
1694            the corresponding main target tensor
1695        ssl_target : torch.Tensor
1696            the corresponding SSL target tensor
1697        tags : torch.Tensor
1698            the tensor of meta tags (or `None`, if tags are not given)
1699
1700        """
1701        mask = target != self.ignore_index
1702        self.total += torch.sum(mask)
1703        self.correct += torch.sum((target == predicted)[mask])

Update the internal state (with a batch).

Parameters

predicted : torch.Tensor the main prediction tensor generated by the model ssl_predicted : torch.Tensor the SSL prediction tensor generated by the model target : torch.Tensor the corresponding main target tensor ssl_target : torch.Tensor the corresponding SSL target tensor tags : torch.Tensor the tensor of meta tags (or None, if tags are not given)

class Count(dlc2action.metric.base_metric.Metric):
class Count(Metric):
    """Fraction of samples labeled by the model as a class."""

    def __init__(self, classes: Set, exclusive: bool = True):
        """Initialize the class.

        Parameters
        ----------
        classes : set
            the set of classes to count
        exclusive : bool, default True
            set to False for multi-label classification tasks

        """
        super().__init__()
        self.classes = classes
        self.exclusive = exclusive

    def reset(self) -> None:
        """Reset the internal state (at the beginning of an epoch)."""
        # `int` (rather than a lambda) as the default factory keeps the state picklable.
        self.count = defaultdict(int)
        self.total = 0

    def update(
        self,
        predicted: torch.Tensor,
        target: torch.Tensor,
        tags: torch.Tensor = None,
    ) -> None:
        """Update the internal state (with a batch).

        Parameters
        ----------
        predicted : torch.Tensor
            the main prediction tensor generated by the model; in the non-exclusive case it is
            indexed as `predicted[:, c, :]`, i.e. dimension 1 is treated as the class axis
        target : torch.Tensor
            the corresponding main target tensor (not used by this metric)
        tags : torch.Tensor, optional
            the tensor of meta tags (not used by this metric)

        """
        if self.exclusive:
            # Exclusive predictions hold class indices: count the entries equal to each class.
            for c in self.classes:
                self.count[c] += torch.sum(predicted == c)
            self.total += torch.numel(predicted)
        else:
            # Multi-label predictions: count the entries equal to 1 in each class slice.
            for c in self.classes:
                self.count[c] += torch.sum(predicted[:, c, :] == 1)
            self.total += torch.numel(predicted[:, 0, :])

    def calculate(self) -> Dict:
        """Calculate the metric (at the end of an epoch).

        The accumulated counts are left untouched, so calling this method repeatedly
        returns the same values.

        Returns
        -------
        result : dict
            a dictionary where the keys are class indices and the values are class metric values

        """
        # The small epsilon keeps the division defined when nothing has been counted yet.
        denominator = self.total + 1e-7
        return {c: self.count[c] / denominator for c in self.classes}

Fraction of samples labeled by the model as a class.

Count(classes: Set, exclusive: bool = True)
1709    def __init__(self, classes: Set, exclusive: bool = True):
1710        """Initialize the class.
1711
1712        Parameters
1713        ----------
1714        classes : set
1715            the set of classes to count
1716        exclusive: bool, default True
1717            set to False for multi-label classification tasks
1718
1719        """
1720        super().__init__()
1721        self.classes = classes
1722        self.exclusive = exclusive

Initialize the class.

Parameters

classes : set the set of classes to count exclusive: bool, default True set to False for multi-label classification tasks

classes
exclusive
def reset(self) -> None:
1724    def reset(self) -> None:
1725        """Reset the internal state (at the beginning of an epoch)."""
1726        self.count = defaultdict(lambda: 0)
1727        self.total = 0

Reset the internal state (at the beginning of an epoch).

def update( self, predicted: torch.Tensor, target: torch.Tensor, tags: torch.Tensor) -> None:
1729    def update(
1730        self,
1731        predicted: torch.Tensor,
1732        target: torch.Tensor,
1733        tags: torch.Tensor,
1734    ) -> None:
1735        """Update the internal state (with a batch).
1736
1737        Parameters
1738        ----------
1739        predicted : torch.Tensor
1740            the main prediction tensor generated by the model
1741        ssl_predicted : torch.Tensor
1742            the SSL prediction tensor generated by the model
1743        target : torch.Tensor
1744            the corresponding main target tensor
1745        ssl_target : torch.Tensor
1746            the corresponding SSL target tensor
1747        tags : torch.Tensor
1748            the tensor of meta tags (or `None`, if tags are not given)
1749
1750        """
1751        if self.exclusive:
1752            for c in self.classes:
1753                self.count[c] += torch.sum(predicted == c)
1754            self.total += torch.numel(predicted)
1755        else:
1756            for c in self.classes:
1757                self.count[c] += torch.sum(predicted[:, c, :] == 1)
1758            self.total += torch.numel(predicted[:, 0, :])

Update the internal state (with a batch).

Parameters

predicted : torch.Tensor the main prediction tensor generated by the model ssl_predicted : torch.Tensor the SSL prediction tensor generated by the model target : torch.Tensor the corresponding main target tensor ssl_target : torch.Tensor the corresponding SSL target tensor tags : torch.Tensor the tensor of meta tags (or None, if tags are not given)

def calculate(self) -> Dict:
1760    def calculate(self) -> Dict:
1761        """Calculate the metric (at the end of an epoch).
1762
1763        Returns
1764        -------
1765        result : dict
1766            a dictionary where the keys are class indices and the values are class metric values
1767
1768        """
1769        for c in self.classes:
1770            self.count[c] = self.count[c] / (self.total + 1e-7)
1771        return dict(self.count)

Calculate the metric (at the end of an epoch).

Returns

result : dict a dictionary where the keys are class indices and the values are class metric values

class EditDistance(dlc2action.metric.base_metric.Metric):
class EditDistance(Metric):
    """Edit distance (not advised for training).

    Normalized by the length of the sequences.
    """

    def __init__(self, ignore_index: int = -100, exclusive: bool = True) -> None:
        """Initialize the class.

        Parameters
        ----------
        ignore_index : int, default -100
            the class index that indicates samples that should be ignored
        exclusive : bool, default True
            set to False for multi-label classification tasks

        """
        super().__init__()
        self.ignore_index, self.exclusive = ignore_index, exclusive

    def reset(self) -> None:
        """Reset the internal state (at the beginning of an epoch)."""
        self.edit_distance, self.total = 0, 0

    def update(
        self,
        predicted: torch.Tensor,
        target: torch.Tensor,
        tags: torch.Tensor,
    ) -> None:
        """Update the internal state (with a batch).

        Parameters
        ----------
        predicted : torch.Tensor
            the main prediction tensor generated by the model
        target : torch.Tensor
            the corresponding main target tensor; in the non-exclusive case it is indexed as
            `target[:, c, :]`, i.e. dimension 1 is treated as the class axis
        tags : torch.Tensor
            the tensor of meta tags (not used by this metric)

        """
        # Positions whose target equals `ignore_index` are dropped before comparing.
        valid = target != self.ignore_index
        self.total += torch.sum(valid)

        if not self.exclusive:
            # Multi-label case: one edit distance per class slice, summed up.
            for class_index in range(target.shape[1]):
                class_valid = valid[:, class_index, :]
                predicted_labels = predicted[:, class_index, :][class_valid].flatten()
                target_labels = target[:, class_index, :][class_valid].flatten()
                self.edit_distance += editdistance.eval(
                    predicted_labels.detach().cpu().tolist(),
                    target_labels.detach().cpu().tolist(),
                )
            return

        predicted_labels = predicted[valid].flatten().detach().cpu().numpy()
        target_labels = target[valid].flatten().detach().cpu().numpy()
        self.edit_distance += editdistance.eval(predicted_labels, target_labels)

    def _is_equal(self, a, b):
        """Compare while ignoring samples marked with ignore_index."""
        return bool(self.ignore_index in [a, b] or a == b)

    def calculate(self) -> float:
        """Calculate the metric (at the end of an epoch).

        Returns
        -------
        result : float
            the accumulated edit distance divided by the number of counted samples

        """
        # The small epsilon keeps the division defined when nothing has been counted yet.
        denominator = self.total + 1e-7
        return self.edit_distance / denominator

Edit distance (not advised for training).

Normalized by the length of the sequences.

EditDistance(ignore_index: int = -100, exclusive: bool = True)
1780    def __init__(self, ignore_index: int = -100, exclusive: bool = True) -> None:
1781        """Initialize the class.
1782
1783        Parameters
1784        ----------
1785        ignore_index : int, default -100
1786            the class index that indicates samples that should be ignored
1787        exclusive : bool, default True
1788            set to False for multi-label classification tasks
1789
1790        """
1791        super().__init__()
1792        self.ignore_index = ignore_index
1793        self.exclusive = exclusive

Initialize the class.

Parameters

ignore_index : int, default -100 the class index that indicates samples that should be ignored exclusive : bool, default True set to False for multi-label classification tasks

ignore_index
exclusive
def reset(self) -> None:
1795    def reset(self) -> None:
1796        """Reset the internal state (at the beginning of an epoch)."""
1797        self.edit_distance = 0
1798        self.total = 0

Reset the internal state (at the beginning of an epoch).

def update( self, predicted: torch.Tensor, target: torch.Tensor, tags: torch.Tensor) -> None:
1800    def update(
1801        self,
1802        predicted: torch.Tensor,
1803        target: torch.Tensor,
1804        tags: torch.Tensor,
1805    ) -> None:
1806        """Update the internal state (with a batch).
1807
1808        Parameters
1809        ----------
1810        predicted : torch.Tensor
1811            the main prediction tensor generated by the model
1812        ssl_predicted : torch.Tensor
1813            the SSL prediction tensor generated by the model
1814        target : torch.Tensor
1815            the corresponding main target tensor
1816        ssl_target : torch.Tensor
1817            the corresponding SSL target tensor
1818        tags : torch.Tensor
1819            the tensor of meta tags (or `None`, if tags are not given)
1820
1821        """
1822        mask = target != self.ignore_index
1823        self.total += torch.sum(mask)
1824        if self.exclusive:
1825            predicted = predicted[mask].flatten()
1826            target = target[mask].flatten()
1827            self.edit_distance += editdistance.eval(
1828                predicted.detach().cpu().numpy(), target.detach().cpu().numpy()
1829            )
1830        else:
1831            for c in range(target.shape[1]):
1832                predicted_class = predicted[:, c, :][mask[:, c, :]].flatten()
1833                target_class = target[:, c, :][mask[:, c, :]].flatten()
1834                self.edit_distance += editdistance.eval(
1835                    predicted_class.detach().cpu().tolist(),
1836                    target_class.detach().cpu().tolist(),
1837                )

Update the internal state (with a batch).

Parameters

predicted : torch.Tensor the main prediction tensor generated by the model ssl_predicted : torch.Tensor the SSL prediction tensor generated by the model target : torch.Tensor the corresponding main target tensor ssl_target : torch.Tensor the corresponding SSL target tensor tags : torch.Tensor the tensor of meta tags (or None, if tags are not given)

def calculate(self) -> float:
1846    def calculate(self) -> float:
1847        """Calculate the metric (at the end of an epoch).
1848
1849        Returns
1850        -------
1851        result : float
1852            the metric value
1853
1854        """
1855        return self.edit_distance / (self.total + 1e-7)

Calculate the metric (at the end of an epoch).

Returns

result : float the metric value

class mAP(dlc2action.metric.base_metric.Metric):
class mAP(Metric):
    """Mean average precision (segmental) (not advised for training).

    Frame-wise predictions and targets are converted to intervals of
    consecutive frames, predicted intervals are matched to target intervals
    by overlap and the average precision is computed from the matches.
    """

    # The metric works on raw model outputs: `update` thresholds them or takes
    # the argmax itself.
    needs_raw_data = True

    def __init__(
        self,
        exclusive,
        num_classes,
        average="macro",
        iou_threshold=0.5,
        threshold_value=0.5,
        ignored_classes=None,
    ):
        """Initialize the class.

        Parameters
        ----------
        exclusive : bool
            set to False for multi-label classification tasks
        num_classes : int
            the number of classes
        average : {"macro", "micro", "none"}
            the type of averaging to perform
        iou_threshold : float, default 0.5
            the IoU threshold for matching
        threshold_value : float, default 0.5
            the threshold value for binarization
        ignored_classes : list, default None
            the list of classes to ignore

        """
        if ignored_classes is None:
            ignored_classes = []
        self.average = average
        self.iou_threshold = iou_threshold
        self.threshold = threshold_value
        self.exclusive = exclusive
        self.classes = [x for x in range(num_classes) if x not in ignored_classes]
        # The attributes are set before the base initializer runs, as in the
        # original ordering.
        super().__init__()

    def match(self, lst, ratio, ground):
        """Given a list of proposals, match them to the ground truth boxes.

        Parameters
        ----------
        lst : iterable
            the proposals, each a `(start, end, confidence)` triple
        ratio : float
            the minimum overlap needed for a proposal to be matched
        ground : sequence
            the ground truth intervals, each a `(start, end, _)` triple

        Returns
        -------
        cos_map : list
            for every proposal (in ascending confidence order) the index of the
            matched ground truth interval or -1 if there is no match
        count_map : list
            for every ground truth interval the number of proposals matched to it
        positive : int
            the number of ground truth intervals with at least one match
        confidence : list
            the proposal confidences in ascending order

        """
        proposals = sorted(lst, key=lambda x: x[2])

        def overlap(prop, truth):
            # Intersection length divided by the length of the span covered
            # by both intervals (negative for disjoint intervals).
            s_p, e_p, _ = prop
            s_g, e_g, _ = truth
            return (min(e_p, e_g) - max(s_p, s_g)) / (max(e_p, e_g) - min(s_p, s_g))

        cos_map = []
        count_map = [0] * len(ground)
        for proposal in proposals:
            best_index, best_overlap = -1, None
            for index, truth in enumerate(ground):
                current = overlap(proposal, truth)
                if current < ratio:
                    continue
                # Ties go to the later ground truth interval.
                if best_overlap is not None and current < best_overlap:
                    continue
                best_index, best_overlap = index, current
            cos_map.append(best_index)
            if best_index != -1:
                count_map[best_index] += 1
        positive = sum(1 for count in count_map if count > 0)
        return cos_map, count_map, positive, [x[2] for x in proposals]

    def reset(self) -> None:
        """Reset the internal state (at the beginning of an epoch)."""
        # All four dictionaries are keyed by class index.
        self.count_map = defaultdict(list)
        self.positive = defaultdict(int)
        self.cos_map = defaultdict(list)
        self.confidence = defaultdict(list)

    def calc_pr(self, positive, proposal, ground):
        """Get precision and recall.

        Returns `(0, 0)` when there are no proposals or no ground truth
        intervals.
        """
        if proposal == 0 or ground == 0:
            return 0, 0
        return (1.0 * positive) / proposal, (1.0 * positive) / ground

    def calculate(self) -> Union[float, Dict]:
        """Calculate the metric (at the end of an epoch).

        Returns
        -------
        result : float | dict
            a dictionary of per-class values if `average` is `"none"`, a single
            value otherwise

        """
        if self.average == "micro":
            confidence = []
            count_map = []
            cos_map = []
            positive = sum(self.positive.values())
            for key in self.count_map:
                # Ground truth indices are shifted so that they point into the
                # concatenated count map; -1 ("unmatched") must stay -1.
                offset = len(count_map)
                confidence += self.confidence[key]
                cos_map += [m if m == -1 else m + offset for m in self.cos_map[key]]
                count_map += self.count_map[key]
            return self.ap(cos_map, count_map, positive, confidence)
        results = {
            key: self.ap(
                self.cos_map[key],
                self.count_map[key],
                self.positive[key],
                self.confidence[key],
            )
            for key in self.count_map
        }
        if self.average == "none":
            return results
        return float(np.mean(list(results.values())))

    def ap(self, cos_map, count_map, positive, confidence):
        """Compute average precision.

        Proposals are removed one by one in ascending confidence order and the
        precision is integrated over the resulting recall steps.

        Parameters
        ----------
        cos_map : list
            for every proposal the index of the matched ground truth interval
            (-1 if unmatched)
        count_map : list
            for every ground truth interval the number of matched proposals
            (not modified)
        positive : int
            the number of ground truth intervals with at least one match
        confidence : list
            the confidence of every proposal

        Returns
        -------
        score : float
            the average precision

        """
        order = np.argsort(confidence)
        matches = list(np.array(cos_map)[order])
        # Work on a copy: the counts are decremented below and the caller's
        # accumulated state must survive repeated calls.
        remaining = list(count_map)
        score = 0
        number_proposal = len(matches)
        number_ground = len(remaining)
        old_precision, old_recall = self.calc_pr(
            positive, number_proposal, number_ground
        )

        for matched in matches:
            number_proposal -= 1
            if matched == -1:
                continue
            remaining[matched] -= 1
            if remaining[matched] == 0:
                positive -= 1

            precision, recall = self.calc_pr(positive, number_proposal, number_ground)
            if precision > old_precision:
                old_precision = precision
            score += old_precision * (old_recall - recall)
            old_recall = recall
        return score

    def _get_intervals(
        self, tensor: torch.Tensor, probability: torch.Tensor = None
    ) -> Union[Tuple, torch.Tensor]:
        """Get `True` group beginning and end indices from a boolean tensor and average probability over these intervals.

        Returns a tensor with one `(start, end, confidence)` row per group of
        consecutive `True` values; `end` is exclusive.
        """
        output, indices = torch.unique_consecutive(tensor, return_inverse=True)
        true_indices = torch.where(output)[0]
        starts, ends, confidence = [], [], []
        for i in true_indices:
            # Frame positions of the group, computed once per interval.
            frames = (indices == i).nonzero(as_tuple=True)[0]
            starts.append(frames[0])
            ends.append(frames[-1] + 1)
            confidence.append(probability[frames].mean())
        return torch.stack(
            [torch.tensor(starts), torch.tensor(ends), torch.tensor(confidence)]
        ).T

    def update(
        self,
        predicted: torch.Tensor,
        target: torch.Tensor,
        tags: torch.Tensor,
    ) -> None:
        """Update the state (at the end of each batch).

        Parameters
        ----------
        predicted : torch.Tensor
            the prediction tensor; indexed as `(sample, class, frame)` below
        target : torch.Tensor
            the corresponding target tensor
        tags : torch.Tensor
            the tensor of meta tags (not used by this metric)

        """
        # A frame filled with -100 is appended to every sample so that
        # intervals cannot run across sample borders after flattening.
        predicted = torch.cat(
            [
                predicted,
                -100 * torch.ones((*predicted.shape[:-1], 1)).to(predicted.device),
            ],
            dim=-1,
        )
        target = torch.cat(
            [
                target,
                -100 * torch.ones((*target.shape[:-1], 1)).to(target.device),
            ],
            dim=-1,
        )
        num_classes = predicted.shape[1]
        padded_length = predicted.shape[-1]
        probability = predicted.transpose(1, 2).reshape(-1, num_classes)
        if self.exclusive:
            target = target.flatten()
            predicted = torch.max(probability, 1)[1]
            # All classes share the same score on the padding frames, so the
            # argmax is meaningless there; mark them with a label that
            # matches no class to keep them out of every interval.
            frame_index = torch.arange(predicted.shape[0], device=predicted.device)
            predicted[(frame_index + 1) % padded_length == 0] = -1
        else:
            target = target.transpose(1, 2).reshape(-1, num_classes)
            predicted = (probability > self.threshold).int()
        for c in self.classes:
            if self.exclusive:
                predicted_mask = predicted == c
                target_mask = target == c
            else:
                predicted_mask = predicted[:, c] == 1
                target_mask = target[:, c] == 1
            predicted_intervals = self._get_intervals(
                predicted_mask, probability=probability[:, c]
            )
            target_intervals = self._get_intervals(
                target_mask, probability=probability[:, c]
            )
            cos_map, count_map, positive, confidence = self.match(
                predicted_intervals, self.iou_threshold, target_intervals
            )
            cos_map = np.array(cos_map)
            # Shift matched indices past the ground truth intervals that were
            # accumulated for this class in earlier batches.
            cos_map[cos_map != -1] += len(self.count_map[c])
            self.cos_map[c] += list(cos_map)
            self.count_map[c] += count_map
            self.confidence[c] += confidence
            self.positive[c] += positive

Mean average precision (segmental) (not advised for training).

mAP( exclusive, num_classes, average='macro', iou_threshold=0.5, threshold_value=0.5, ignored_classes=None)
1863    def __init__(
1864        self,
1865        exclusive,
1866        num_classes,
1867        average="macro",
1868        iou_threshold=0.5,
1869        threshold_value=0.5,
1870        ignored_classes=None,
1871    ):
1872        """Initialize the class.
1873
1874        Parameters
1875        ----------
1876        exclusive : bool
1877            set to False for multi-label classification tasks
1878        num_classes : int
1879            the number of classes
1880        average : {"macro", "micro", "none"}
1881            the type of averaging to perform
1882        iou_threshold : float, default 0.5
1883            the IoU threshold for matching
1884        threshold_value : float, default 0.5
1885            the threshold value for binarization
1886        ignored_classes : list, default None
1887            the list of classes to ignore
1888
1889        """
1890        if ignored_classes is None:
1891            ignored_classes = []
1892        self.average = average
1893        self.iou_threshold = iou_threshold
1894        self.threshold = threshold_value
1895        self.exclusive = exclusive
1896        self.classes = [x for x in list(range(num_classes)) if x not in ignored_classes]
1897        super().__init__()

Initialize the class.

Parameters

exclusive : bool
    set to False for multi-label classification tasks
num_classes : int
    the number of classes
average : {"macro", "micro", "none"}
    the type of averaging to perform
iou_threshold : float, default 0.5
    the IoU threshold for matching
threshold_value : float, default 0.5
    the threshold value for binarization
ignored_classes : list, default None
    the list of classes to ignore

needs_raw_data = True

If True, dlc2action.task.universal_task.Task will pass raw data to the metric (only primary predict function applied). Otherwise it will pass a prediction for the classes.

average
iou_threshold
threshold
exclusive
classes
def match(self, lst, ratio, ground):
1899    def match(self, lst, ratio, ground):
1900        """Given a list of proposals, match them to the ground truth boxes."""
1901        lst = sorted(lst, key=lambda x: x[2])
1902
1903        def overlap(prop, ground):
1904            s_p, e_p, _ = prop
1905            s_g, e_g, _ = ground
1906            return (min(e_p, e_g) - max(s_p, s_g)) / (max(e_p, e_g) - min(s_p, s_g))
1907
1908        cos_map = [-1 for x in range(len(lst))]
1909        count_map = [0 for x in range(len(ground))]
1910
1911        for x in range(len(lst)):
1912            for y in range(len(ground)):
1913                if overlap(lst[x], ground[y]) < ratio:
1914                    continue
1915                if overlap(lst[x], ground[y]) < overlap(lst[x], ground[cos_map[x]]):
1916                    continue
1917                cos_map[x] = y
1918            if cos_map[x] != -1:
1919                count_map[cos_map[x]] += 1
1920        positive = sum([(x > 0) for x in count_map])
1921        return cos_map, count_map, positive, [x[2] for x in lst]

Given a list of proposals, match them to the ground truth boxes.

def reset(self) -> None:
1923    def reset(self) -> None:
1924        """Reset the internal state (at the beginning of an epoch)."""
1925        self.count_map = defaultdict(lambda: [])
1926        self.positive = defaultdict(lambda: 0)
1927        self.cos_map = defaultdict(lambda: [])
1928        self.confidence = defaultdict(lambda: [])

Reset the internal state (at the beginning of an epoch).

def calc_pr(self, positive, proposal, ground):
1930    def calc_pr(self, positive, proposal, ground):
1931        """Get precision."""
1932        if proposal == 0:
1933            return 0, 0
1934        if ground == 0:
1935            return 0, 0
1936        return (1.0 * positive) / proposal, (1.0 * positive) / ground

Get precision and recall.

def calculate(self) -> Union[float, Dict]:
1938    def calculate(self) -> Union[float, Dict]:
1939        """Calculate the metric (at the end of an epoch)."""
1940        if self.average == "micro":
1941            confidence = []
1942            count_map = []
1943            cos_map = []
1944            positive = sum(self.positive.values())
1945            for key in self.count_map.keys():
1946                confidence += self.confidence[key]
1947                cos_map += list(np.array(self.cos_map[key]) + len(count_map))
1948                count_map += self.count_map[key]
1949            return self.ap(cos_map, count_map, positive, confidence)
1950        results = {
1951            key: self.ap(
1952                self.cos_map[key],
1953                self.count_map[key],
1954                self.positive[key],
1955                self.confidence[key],
1956            )
1957            for key in self.count_map.keys()
1958        }
1959        if self.average == "none":
1960            return results
1961        else:
1962            return float(np.mean(list(results.values())))

Calculate the metric (at the end of an epoch).

def ap(self, cos_map, count_map, positive, confidence):
1964    def ap(self, cos_map, count_map, positive, confidence):
1965        """Compute average precision."""
1966        indices = np.argsort(confidence)
1967        cos_map = list(np.array(cos_map)[indices])
1968        score = 0
1969        number_proposal = len(cos_map)
1970        number_ground = len(count_map)
1971        old_precision, old_recall = self.calc_pr(
1972            positive, number_proposal, number_ground
1973        )
1974
1975        for x in range(len(cos_map)):
1976            number_proposal -= 1
1977            if cos_map[x] == -1:
1978                continue
1979            count_map[cos_map[x]] -= 1
1980            if count_map[cos_map[x]] == 0:
1981                positive -= 1
1982
1983            precision, recall = self.calc_pr(positive, number_proposal, number_ground)
1984            if precision > old_precision:
1985                old_precision = precision
1986            score += old_precision * (old_recall - recall)
1987            old_recall = recall
1988        return score

Compute average precision.

def update( self, predicted: torch.Tensor, target: torch.Tensor, tags: torch.Tensor) -> None:
2007    def update(
2008        self,
2009        predicted: torch.Tensor,
2010        target: torch.Tensor,
2011        tags: torch.Tensor,
2012    ) -> None:
2013        """Update the state (at the end of each batch)."""
2014        predicted = torch.cat(
2015            [
2016                copy(predicted),
2017                -100 * torch.ones((*predicted.shape[:-1], 1)).to(predicted.device),
2018            ],
2019            dim=-1,
2020        )
2021        target = torch.cat(
2022            [
2023                copy(target),
2024                -100 * torch.ones((*target.shape[:-1], 1)).to(target.device),
2025            ],
2026            dim=-1,
2027        )
2028        num_classes = predicted.shape[1]
2029        predicted = predicted.transpose(1, 2).reshape(-1, num_classes)
2030        if self.exclusive:
2031            target = target.flatten()
2032        else:
2033            target = target.transpose(1, 2).reshape(-1, num_classes)
2034        probability = copy(predicted)
2035        if not self.exclusive:
2036            predicted = (predicted > self.threshold).int()
2037        else:
2038            predicted = torch.max(predicted, 1)[1]
2039        for c in self.classes:
2040            if self.exclusive:
2041                predicted_intervals = self._get_intervals(
2042                    predicted == c, probability=probability[:, c]
2043                )
2044                target_intervals = self._get_intervals(
2045                    target == c, probability=probability[:, c]
2046                )
2047            else:
2048                predicted_intervals = self._get_intervals(
2049                    predicted[:, c] == 1, probability=probability[:, c]
2050                )
2051                target_intervals = self._get_intervals(
2052                    target[:, c] == 1, probability=probability[:, c]
2053                )
2054            cos_map, count_map, positive, confidence = self.match(
2055                predicted_intervals, self.iou_threshold, target_intervals
2056            )
2057            cos_map = np.array(cos_map)
2058            cos_map[cos_map != -1] += len(self.count_map[c])
2059            self.cos_map[c] += list(cos_map)
2060            self.count_map[c] += count_map
2061            self.confidence[c] += confidence
2062            self.positive[c] += positive

Update the state (at the end of each batch).