dlc2action.metric.metrics

#
# Copyright 2020-2022 by A. Mathis Group and contributors. All rights reserved.
#
# This project and all its files are licensed under GNU AGPLv3 or later version. A copy is included in dlc2action/LICENSE.AGPL.
#
"""
Implementations of `dlc2action.metric.base_metric.Metric`
"""

from typing import Union, Dict, Tuple, List, Set, Any
import torch
from collections import defaultdict
from dlc2action.metric.base_metric import Metric
from abc import abstractmethod
import editdistance
import numpy as np
from copy import copy, deepcopy
import warnings
from sklearn import metrics


class _ClassificationMetric(Metric):
    """
    The base class for all metrics that are calculated from true and false positive and negative counts
    """

    needs_raw_data = True
    segmental = False
    """
    If `True`, the metric is calculated over segments; otherwise over frames.
    """

    def __init__(
        self,
        num_classes: int,
        ignore_index: int = -100,
        average: str = "macro",
        ignored_classes: Set = None,
        exclusive: bool = True,
        iou_threshold: float = 0.5,
        tag_average: str = "micro",
        threshold_values: List = None,
        integration_interval: int = 0,
    ):
        """
        Initialize the metric

        Parameters
        ----------
        num_classes : int
            number of classes in the dataset
        ignore_index : int, default -100
            the class index that indicates ignored samples
        average : {'macro', 'micro', 'none'}
            method for averaging across classes
        ignored_classes : set, optional
            a set of class ids to ignore in the calculation
        exclusive : bool, default True
            set to False for multi-label classification tasks
        iou_threshold : float, default 0.5
            if `segmental` is `True`, intervals with IoU above this threshold are considered correct
        tag_average : {'micro', 'macro', 'none'}
            method for averaging across meta tags (if given)
        threshold_values : list, optional
            a list of float values between 0 and 1 that will be used as decision thresholds
        integration_interval : int, default 0
            if nonzero, the predictions are averaged over a sliding window of `2 * integration_interval + 1`
            frames before thresholding
        """

        super().__init__()
        self.ignore_index = ignore_index
        self.average = average
        self.tag_average = tag_average
        self.tags = set()
        self.exclusive = exclusive
        self.threshold = iou_threshold
        self.integration_interval = integration_interval
        self.classes = list(range(int(num_classes)))
        self.optimize = False
        if threshold_values is None:
            threshold_values = [0.5]
        elif threshold_values == ["optimize"]:
            threshold_values = list(np.arange(0.25, 0.75, 0.05))
            self.optimize = True
        if self.exclusive and threshold_values != [0.5]:
            raise ValueError(
                "Cannot set threshold values for exclusive classification!"
            )
        self.threshold_values = threshold_values

        if ignored_classes is not None:
            for c in ignored_classes:
                if c in self.classes:
                    self.classes.remove(c)
        if len(self.classes) == 0:
            warnings.warn("No classes are followed!")

    def reset(self) -> None:
        """
        Reset the intrinsic parameters (at the beginning of an epoch)
        """

        self.tags = set()
        self.tp = defaultdict(lambda: defaultdict(lambda: 0))
        self.fp = defaultdict(lambda: defaultdict(lambda: 0))
        self.fn = defaultdict(lambda: defaultdict(lambda: 0))
        self.tn = defaultdict(lambda: defaultdict(lambda: 0))

    def update(
        self,
        predicted: torch.Tensor,
        target: torch.Tensor,
        tags: torch.Tensor,
    ) -> None:
        """
        Update the intrinsic parameters (with a batch)

        Parameters
        ----------
        predicted : torch.Tensor
            the main prediction tensor generated by the model
        target : torch.Tensor
            the corresponding main target tensor
        tags : torch.Tensor
            the tensor of meta tags (or `None`, if tags are not given)
        """

        if self.segmental:
            self._update_segmental(predicted, target, tags)
        else:
            self._update_normal(predicted, target, tags)

    def _key(self, tag: torch.Tensor, c: int) -> Union[str, int]:
        """
        Get a key for the intermediate value dictionaries from tag and class indices
        """

        if tag is None or self.tag_average == "micro":
            return c
        else:
            return f"tag{int(tag)}_{c}"

    def _update_normal(
        self,
        predicted: torch.Tensor,
        target: torch.Tensor,
        tags: torch.Tensor,
    ) -> None:
        """
        Update the intrinsic parameters (with a batch), calculating over frames
        """

        if self.integration_interval != 0:
            # Smooth the predictions with a centered moving average over
            # 2 * integration_interval + 1 frames (truncated at the edges)
            pr = predicted.clone()
            denom = torch.ones((1, 1, predicted.shape[-1])) * (
                2 * self.integration_interval + 1
            )
            for i in range(self.integration_interval):
                # add the neighbors at distance i + 1 on both sides
                pr += torch.cat(
                    [
                        torch.zeros((predicted.shape[0], predicted.shape[1], i + 1)),
                        predicted[:, :, : -i - 1],
                    ],
                    dim=-1,
                )
                pr += torch.cat(
                    [
                        predicted[:, :, i + 1 :],
                        torch.zeros((predicted.shape[0], predicted.shape[1], i + 1)),
                    ],
                    dim=-1,
                )
                # edge frames have a truncated window
                denom[:, :, i] = self.integration_interval + i + 1
                denom[:, :, -(i + 1)] = self.integration_interval + i + 1
            predicted = pr / denom
        if self.exclusive:
            predicted = torch.max(predicted, 1)[1]
        if self.tag_average == "micro" or tags is None or tags[0] is None:
            tag_set = {None}
        else:
            tag_set = set(tags)
            self.tags.update(tag_set)
        for thr in self.threshold_values:
            for t in tag_set:
                if t is None:
                    predicted_t = predicted
                    target_t = target
                else:
                    predicted_t = predicted[tags == t]
                    target_t = target[tags == t]
                for c in self.classes:
                    if self.exclusive:
                        pos = (predicted_t == c) * (target_t != self.ignore_index)
                        neg = (predicted_t != c) * (target_t != self.ignore_index)
                        self.tp[self._key(t, c)][thr] += ((target_t == c) * pos).sum()
                        self.fp[self._key(t, c)][thr] += ((target_t != c) * pos).sum()
                        self.fn[self._key(t, c)][thr] += ((target_t == c) * neg).sum()
                        self.tn[self._key(t, c)][thr] += ((target_t != c) * neg).sum()
                    else:
                        if isinstance(thr, list):
                            key = ", ".join(map(str, thr))
                            # clone to avoid modifying the original predictions in place
                            predicted_t = predicted_t.clone()
                            for i, tt in enumerate(thr):
                                predicted_t[:, i, :] = (predicted_t[:, i, :] > tt).int()
                        else:
                            key = thr
                            predicted_t = (predicted_t > thr).int()
                        pos = predicted_t[:, c, :] * (
                            target_t[:, c, :] != self.ignore_index
                        )
                        neg = (predicted_t[:, c, :] != 1) * (
                            target_t[:, c, :] != self.ignore_index
                        )
                        self.tp[self._key(t, c)][key] += (target_t[:, c, :] * pos).sum()
                        self.fp[self._key(t, c)][key] += (
                            (target_t[:, c, :] != 1) * pos
                        ).sum()
                        self.fn[self._key(t, c)][key] += (target_t[:, c, :] * neg).sum()
                        self.tn[self._key(t, c)][key] += (
                            (target_t[:, c, :] != 1) * neg
                        ).sum()

    def _get_intervals(self, tensor: torch.Tensor) -> torch.Tensor:
        """
        Get a list of `True` group beginning and end indices from a boolean tensor
        """

        output, indices = torch.unique_consecutive(tensor, return_inverse=True)
        true_indices = torch.where(output)[0]
        starts = torch.tensor(
            [(indices == i).nonzero(as_tuple=True)[0][0] for i in true_indices]
        )
        ends = torch.tensor(
            [(indices == i).nonzero(as_tuple=True)[0][-1] + 1 for i in true_indices]
        )
        return torch.stack([starts, ends]).T

    def _smooth(self, tensor: torch.Tensor, smooth_interval: int = 1) -> torch.Tensor:
        """
        Get rid of jitter in a non-exclusive classification tensor

        First, fill in intervals of zeros no longer than `smooth_interval`; then remove intervals of ones
        no longer than `smooth_interval`.
        """

        for c in self.classes:
            intervals = self._get_intervals(tensor[:, c] == 0)
            interval_lengths = torch.tensor(
                [interval[1] - interval[0] for interval in intervals]
            )
            short_intervals = intervals[interval_lengths <= smooth_interval]
            for start, end in short_intervals:
                tensor[start:end, c] = 1
            intervals = self._get_intervals(tensor[:, c] == 1)
            interval_lengths = torch.tensor(
                [interval[1] - interval[0] for interval in intervals]
            )
            short_intervals = intervals[interval_lengths <= smooth_interval]
            for start, end in short_intervals:
                tensor[start:end, c] = 0
        return tensor

    def _sigmoid_threshold_function(
        self,
        low_threshold: float,
        high_threshold: float,
        low_length: int,
        high_length: int,
    ):
        """
        Generate a sigmoid threshold function

        The resulting function outputs an intersection threshold given the length of an interval:
        approximately `low_threshold` at `low_length`, approximately `high_threshold` at `high_length`,
        with a smooth transition in between.
        """

        a = 2 / (high_length - low_length)
        b = 1 - a * high_length
        return lambda x: low_threshold + torch.sigmoid(4 * (a * x + b)) * (
            high_threshold - low_threshold
        )

    def _update_segmental(
        self,
        predicted: torch.Tensor,
        target: torch.Tensor,
        tags: torch.Tensor,
    ) -> None:
        """
        Update the intrinsic parameters (with a batch), calculating over segments
        """

        if self.exclusive:
            predicted = torch.max(predicted, 1)[1]
        predicted = torch.cat(
            [
                copy(predicted),
                -100 * torch.ones((*predicted.shape[:-1], 1)).to(predicted.device),
            ],
            dim=-1,
        )
        target = torch.cat(
            [
                copy(target),
                -100 * torch.ones((*target.shape[:-1], 1)).to(target.device),
            ],
            dim=-1,
        )
        if self.exclusive:
            predicted = predicted.flatten()
            target = target.flatten()
        else:
            num_classes = predicted.shape[1]
            predicted = predicted.transpose(1, 2).reshape(-1, num_classes)
            target = target.transpose(1, 2).reshape(-1, num_classes)
        if self.tag_average == "micro" or tags is None or tags[0] is None:
            tag_set = {None}
        else:
            tag_set = set(tags)
            self.tags.update(tag_set)
        for thr in self.threshold_values:
            key = thr
            for t in tag_set:
                if t is None:
                    predicted_t = predicted
                    target_t = target
                else:
                    predicted_t = predicted[tags == t]
                    target_t = target[tags == t]
                if not self.exclusive:
                    if isinstance(thr, list):
                        # clone to avoid modifying the original predictions in place;
                        # after the reshape above, classes are in the second dimension
                        predicted_t = predicted_t.clone()
                        for i, tt in enumerate(thr):
                            predicted_t[:, i] = (predicted_t[:, i] > tt).int()
                        key = ", ".join(map(str, thr))
                    else:
                        predicted_t = (predicted_t > thr).int()
                for c in self.classes:
                    if self.exclusive:
                        predicted_intervals = self._get_intervals(predicted_t == c)
                        target_intervals = self._get_intervals(target_t == c)
                    else:
                        predicted_intervals = self._get_intervals(
                            predicted_t[:, c] == 1
                        )
                        target_intervals = self._get_intervals(target_t[:, c] == 1)
                    true_used = torch.zeros(target_intervals.shape[0])
                    for interval in predicted_intervals:
                        if len(target_intervals) > 0:
                            # Compute IoU against all ground truth intervals
                            intersection = torch.minimum(
                                interval[1], target_intervals[:, 1]
                            ) - torch.maximum(interval[0], target_intervals[:, 0])
                            union = torch.maximum(
                                interval[1], target_intervals[:, 1]
                            ) - torch.minimum(interval[0], target_intervals[:, 0])
                            IoU = intersection / union

                            # Get the best scoring segment
                            idx = IoU.argmax()

                            # If the IoU is high enough and the true segment isn't already used,
                            # it is a true positive; otherwise it is a false positive
                            if IoU[idx] >= self.threshold and not true_used[idx]:
                                self.tp[self._key(t, c)][key] += 1
                                true_used[idx] = 1
                            else:
                                self.fp[self._key(t, c)][key] += 1
                        else:
                            self.fp[self._key(t, c)][key] += 1
                    self.fn[self._key(t, c)][key] += len(true_used) - torch.sum(
                        true_used
                    )

    def calculate(self) -> Union[float, Dict]:
        """
        Calculate the metric (at the end of an epoch)

        Returns
        -------
        result : float | dict
            either the single value of the metric or a dictionary where the keys are class indices and the values
            are class metric values
        """

        if self.tag_average == "micro" or len(self.tags) == 0:
            self.tags = {None}

        def merge(count_dicts):
            # merge per-class count dictionaries key by key
            total = defaultdict(lambda: 0)
            for d in count_dicts:
                for k, v in d.items():
                    total[k] += v
            return total

        result = {}
        self.tags = sorted(list(self.tags))
        for tag in self.tags:
            if self.average == "macro":
                metric_list = []
                for c in self.classes:
                    metric_list.append(
                        self._calculate_metric(
                            self.tp[self._key(tag, c)],
                            self.fp[self._key(tag, c)],
                            self.fn[self._key(tag, c)],
                            self.tn[self._key(tag, c)],
                        )
                    )
                if tag is not None:
                    tag = int(tag)
                result[f"tag{tag}"] = sum(metric_list) / len(metric_list)
            elif self.average == "micro":
                tp = merge([self.tp[self._key(tag, c)] for c in self.classes])
                fn = merge([self.fn[self._key(tag, c)] for c in self.classes])
                fp = merge([self.fp[self._key(tag, c)] for c in self.classes])
                tn = merge([self.tn[self._key(tag, c)] for c in self.classes])
                if tag is not None:
                    tag = int(tag)
                result[f"tag{tag}"] = self._calculate_metric(tp, fp, fn, tn)
            elif self.average == "none":
                metric_dict = {}
                for c in self.classes:
                    metric_dict[self._key(tag, c)] = self._calculate_metric(
                        self.tp[self._key(tag, c)],
                        self.fp[self._key(tag, c)],
                        self.fn[self._key(tag, c)],
                        self.tn[self._key(tag, c)],
                    )
                result.update(metric_dict)
            else:
                raise ValueError(
                    f"The {self.average} averaging method is not available, please choose from "
                    f'["none", "micro", "macro"]'
                )
        if len(self.tags) == 1 and self.average != "none":
            tag = self.tags[0]
            if tag is not None:
                tag = int(tag)
            result = result[f"tag{tag}"]
        elif self.tag_average == "macro":
            if self.average == "none":
                r = {}
                for c in self.classes:
                    r[c] = torch.mean(
                        torch.tensor([result[self._key(tag, c)] for tag in self.tags])
                    )
            else:
                r = torch.mean(
                    torch.tensor([result[f"tag{int(tag)}"] for tag in self.tags])
                )
            result = r
        return result

    @abstractmethod
    def _calculate_metric(self, tp: Dict, fp: Dict, fn: Dict, tn: Dict) -> float:
        """
        Calculate the metric value from true and false positive and negative counts

        Parameters
        ----------
        tp : dict
            a dictionary of true positive counts, indexed by threshold value
        fp : dict
            a dictionary of false positive counts, indexed by threshold value
        fn : dict
            a dictionary of false negative counts, indexed by threshold value
        tn : dict
            a dictionary of true negative counts, indexed by threshold value

        Returns
        -------
        metric : float
            metric value
        """


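# The following is a minimal, self-contained sketch (not part of the original
# module) of the interval-extraction logic used by
# `_ClassificationMetric._get_intervals`: from a boolean tensor it recovers the
# `[start, end)` indices of every consecutive run of `True` values. The input
# values are invented for illustration.
if __name__ == "__main__":
    mask = torch.tensor([0, 1, 1, 0, 0, 1, 0, 1, 1, 1]).bool()
    values, indices = torch.unique_consecutive(mask, return_inverse=True)
    true_groups = torch.where(values)[0]
    starts = torch.tensor(
        [(indices == i).nonzero(as_tuple=True)[0][0] for i in true_groups]
    )
    ends = torch.tensor(
        [(indices == i).nonzero(as_tuple=True)[0][-1] + 1 for i in true_groups]
    )
    print(torch.stack([starts, ends]).T)  # tensor([[1, 3], [5, 6], [7, 10]])
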
class PR_AUC(_ClassificationMetric):
    """
    Area under the precision-recall curve (not advised for training)
    """

    def __init__(
        self,
        num_classes: int,
        ignore_index: int = -100,
        average: str = "macro",
        ignored_classes: Set = None,
        exclusive: bool = False,
        tag_average: str = "micro",
        threshold_step: float = 0.1,
    ):
        """
        Initialize the metric

        Parameters
        ----------
        num_classes : int
            number of classes in the dataset
        ignore_index : int, default -100
            the class index that indicates ignored samples
        average : {'macro', 'micro', 'none'}
            method for averaging across classes
        ignored_classes : set, optional
            a set of class ids to ignore in the calculation
        exclusive : bool, default False
            must be False; PR-AUC is only implemented for multi-label classification tasks
        tag_average : {'micro', 'macro', 'none'}
            method for averaging across meta tags (if given)
        threshold_step : float, default 0.1
            the decision threshold step
        """

        if exclusive:
            raise ValueError(
                "The PR-AUC metric is not implemented for exclusive classification!"
            )
        super().__init__(
            num_classes,
            ignore_index,
            average,
            ignored_classes,
            exclusive,
            tag_average=tag_average,
            threshold_values=list(np.arange(0, 1, threshold_step)),
        )

    def _calculate_metric(self, tp: Dict, fp: Dict, fn: Dict, tn: Dict) -> float:
        """
        Calculate the metric value from true and false positive and negative counts
        """

        precisions = []
        recalls = []
        for k in sorted(self.threshold_values):
            precisions.append(tp[k] / (tp[k] + fp[k] + 1e-7))
            recalls.append(tp[k] / (tp[k] + fn[k] + 1e-7))
        return metrics.auc(x=recalls, y=precisions)

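# A hedged sketch (counts invented) of how `PR_AUC._calculate_metric` assembles
# the area under the precision-recall curve from per-threshold counts: one
# precision/recall pair per decision threshold, integrated with `sklearn`.
if __name__ == "__main__":
    thresholds = [0.0, 0.25, 0.5, 0.75]
    tp = {0.0: 90, 0.25: 80, 0.5: 60, 0.75: 30}
    fp = {0.0: 50, 0.25: 20, 0.5: 10, 0.75: 2}
    fn = {0.0: 10, 0.25: 20, 0.5: 40, 0.75: 70}
    precisions = [tp[k] / (tp[k] + fp[k] + 1e-7) for k in thresholds]
    recalls = [tp[k] / (tp[k] + fn[k] + 1e-7) for k in thresholds]
    print(metrics.auc(x=recalls, y=precisions))
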
class Precision(_ClassificationMetric):
    """
    Precision
    """

    def __init__(
        self,
        num_classes: int,
        ignore_index: int = -100,
        average: str = "macro",
        ignored_classes: Set = None,
        exclusive: bool = True,
        tag_average: str = "micro",
        threshold_value: Union[float, List] = None,
    ):
        """
        Parameters
        ----------
        num_classes : int
            number of classes in the dataset
        ignore_index : int, default -100
            the class index that indicates ignored samples
        average : {'macro', 'micro', 'none'}
            method for averaging across classes
        ignored_classes : set, optional
            a set of class ids to ignore in the calculation
        exclusive : bool, default True
            set to False for multi-label classification tasks
        tag_average : {'micro', 'macro', 'none'}
            method for averaging across meta tags (if given)
        threshold_value : float | list, optional
            the decision threshold value (cannot be set for exclusive classification; 0.5 by default
            for non-exclusive); if `threshold_value` is a list, each value corresponds to the class
            with the same index
        """

        if threshold_value is None:
            threshold_value = 0.5
        super().__init__(
            num_classes,
            ignore_index,
            average,
            ignored_classes,
            exclusive,
            tag_average=tag_average,
            threshold_values=[threshold_value],
        )

    def _calculate_metric(self, tp: Dict, fp: Dict, fn: Dict, tn: Dict) -> float:
        """
        Calculate the metric value from true and false positive and negative counts
        """

        k = self.threshold_values[0]
        if isinstance(k, list):
            k = ", ".join(map(str, k))
        return tp[k] / (tp[k] + fp[k] + 1e-7)

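# A usage sketch for `Precision` with invented values: frame-wise per-class
# precision on one multi-label sequence with 2 classes and 4 frames.
if __name__ == "__main__":
    metric = Precision(num_classes=2, average="none", exclusive=False)
    metric.reset()
    predicted = torch.tensor(
        [[[0.9, 0.2, 0.8, 0.1], [0.4, 0.7, 0.6, 0.3]]]
    )  # (batch, class, frame) probabilities
    target = torch.tensor([[[1, 0, 1, 0], [0, 1, 0, 0]]])
    metric.update(predicted, target, tags=None)
    print(metric.calculate())  # approximately {0: 1.0, 1: 0.5}
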
class SegmentalPrecision(_ClassificationMetric):
    """
    Segmental precision (not advised for training)
    """

    segmental = True

    def __init__(
        self,
        num_classes: int,
        ignore_index: int = -100,
        average: str = "macro",
        ignored_classes: Set = None,
        exclusive: bool = True,
        iou_threshold: float = 0.5,
        tag_average: str = "micro",
        threshold_value: Union[float, List] = None,
    ):
        """
        Parameters
        ----------
        num_classes : int
            number of classes in the dataset
        ignore_index : int, default -100
            the class index that indicates ignored samples
        average : {'macro', 'micro', 'none'}
            method for averaging across classes
        ignored_classes : set, optional
            a set of class ids to ignore in the calculation
        exclusive : bool, default True
            set to False for multi-label classification tasks
        iou_threshold : float, default 0.5
            intervals with IoU above this threshold are considered correct
        tag_average : {'micro', 'macro', 'none'}
            method for averaging across meta tags (if given)
        threshold_value : float | list, optional
            the decision threshold value (cannot be set for exclusive classification; 0.5 by default
            for non-exclusive); if `threshold_value` is a list, each value corresponds to the class
            with the same index
        """

        if threshold_value is None:
            threshold_value = 0.5
        super().__init__(
            num_classes,
            ignore_index,
            average,
            ignored_classes,
            exclusive,
            iou_threshold,
            tag_average,
            [threshold_value],
        )

    def _calculate_metric(self, tp: Dict, fp: Dict, fn: Dict, tn: Dict) -> float:
        """
        Calculate the metric value from true and false positive and negative counts
        """

        k = self.threshold_values[0]
        if isinstance(k, list):
            k = ", ".join(map(str, k))
        return tp[k] / (tp[k] + fp[k] + 1e-7)

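# A worked toy example (numbers invented) of the segmental matching rule used
# by the segmental metrics: a predicted segment is a true positive only if its
# IoU with the best-matching, not-yet-used ground truth segment reaches
# `iou_threshold`.
if __name__ == "__main__":
    predicted_interval = torch.tensor([10, 30])  # [start, end)
    target_intervals = torch.tensor([[0, 15], [25, 60]])
    intersection = torch.minimum(
        predicted_interval[1], target_intervals[:, 1]
    ) - torch.maximum(predicted_interval[0], target_intervals[:, 0])
    union = torch.maximum(
        predicted_interval[1], target_intervals[:, 1]
    ) - torch.minimum(predicted_interval[0], target_intervals[:, 0])
    print(intersection / union)  # tensor([0.1667, 0.1000]): both below 0.5, so
    # this prediction would count as a false positive at the default threshold
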
class Recall(_ClassificationMetric):
    """
    Recall
    """

    def __init__(
        self,
        num_classes: int,
        ignore_index: int = -100,
        average: str = "macro",
        ignored_classes: Set = None,
        exclusive: bool = True,
        tag_average: str = "micro",
        threshold_value: Union[float, List] = None,
    ):
        """
        Parameters
        ----------
        num_classes : int
            number of classes in the dataset
        ignore_index : int, default -100
            the class index that indicates ignored samples
        average : {'macro', 'micro', 'none'}
            method for averaging across classes
        ignored_classes : set, optional
            a set of class ids to ignore in the calculation
        exclusive : bool, default True
            set to False for multi-label classification tasks
        tag_average : {'micro', 'macro', 'none'}
            method for averaging across meta tags (if given)
        threshold_value : float | list, optional
            the decision threshold value (cannot be set for exclusive classification; 0.5 by default
            for non-exclusive); if `threshold_value` is a list, each value corresponds to the class
            with the same index
        """

        if threshold_value is None:
            threshold_value = 0.5
        super().__init__(
            num_classes,
            ignore_index,
            average,
            ignored_classes,
            exclusive,
            tag_average=tag_average,
            threshold_values=[threshold_value],
        )

    def _calculate_metric(self, tp: Dict, fp: Dict, fn: Dict, tn: Dict) -> float:
        """
        Calculate the metric value from true and false positive and negative counts
        """

        k = self.threshold_values[0]
        if isinstance(k, list):
            k = ", ".join(map(str, k))
        return tp[k] / (tp[k] + fn[k] + 1e-7)

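# A usage sketch for `Recall` with invented values: exclusive classification,
# where `predicted` holds per-class scores and frames labeled -100 are ignored.
if __name__ == "__main__":
    metric = Recall(num_classes=2, average="macro", exclusive=True)
    metric.reset()
    predicted = torch.tensor(
        [[[2.0, 0.1, 0.3, 2.5], [0.5, 1.2, 1.1, 0.2]]]
    )  # (batch, class, frame) scores; argmax gives classes [0, 1, 1, 0]
    target = torch.tensor([[0, 1, 0, -100]])  # (batch, frame)
    metric.update(predicted, target, tags=None)
    print(metric.calculate())  # approximately (0.5 + 1.0) / 2 = 0.75
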
class SegmentalRecall(_ClassificationMetric):
    """
    Segmental recall (not advised for training)
    """

    segmental = True

    def __init__(
        self,
        num_classes: int,
        ignore_index: int = -100,
        average: str = "macro",
        ignored_classes: Set = None,
        exclusive: bool = True,
        iou_threshold: float = 0.5,
        tag_average: str = "micro",
        threshold_value: Union[float, List] = None,
    ):
        """
        Parameters
        ----------
        num_classes : int
            number of classes in the dataset
        ignore_index : int, default -100
            the class index that indicates ignored samples
        average : {'macro', 'micro', 'none'}
            method for averaging across classes
        ignored_classes : set, optional
            a set of class ids to ignore in the calculation
        exclusive : bool, default True
            set to False for multi-label classification tasks
        iou_threshold : float, default 0.5
            intervals with IoU above this threshold are considered correct
        tag_average : {'micro', 'macro', 'none'}
            method for averaging across meta tags (if given)
        threshold_value : float | list, optional
            the decision threshold value (cannot be set for exclusive classification; 0.5 by default
            for non-exclusive); if `threshold_value` is a list, each value corresponds to the class
            with the same index
        """

        if threshold_value is None:
            threshold_value = 0.5
        super().__init__(
            num_classes,
            ignore_index,
            average,
            ignored_classes,
            exclusive,
            iou_threshold,
            tag_average,
            [threshold_value],
        )

    def _calculate_metric(self, tp: Dict, fp: Dict, fn: Dict, tn: Dict) -> float:
        """
        Calculate the metric value from true and false positive and negative counts
        """

        k = self.threshold_values[0]
        if isinstance(k, list):
            k = ", ".join(map(str, k))
        return tp[k] / (tp[k] + fn[k] + 1e-7)

class F1(_ClassificationMetric):
    """
    F1 score
    """

    def __init__(
        self,
        num_classes: int,
        ignore_index: int = -100,
        average: str = "macro",
        ignored_classes: Set = None,
        exclusive: bool = True,
        tag_average: str = "micro",
        threshold_value: Union[float, List] = None,
        integration_interval: int = 0,
    ):
        """
        Parameters
        ----------
        num_classes : int
            number of classes in the dataset
        ignore_index : int, default -100
            the class index that indicates ignored samples
        average : {'macro', 'micro', 'none'}
            method for averaging across classes
        ignored_classes : set, optional
            a set of class ids to ignore in the calculation
        exclusive : bool, default True
            set to False for multi-label classification tasks
        tag_average : {'micro', 'macro', 'none'}
            method for averaging across meta tags (if given)
        threshold_value : float | list, optional
            the decision threshold value (cannot be set for exclusive classification; 0.5 by default
            for non-exclusive); if `threshold_value` is a list, each value corresponds to the class
            with the same index; if set to `"optimize"`, a grid of thresholds between 0.25 and 0.7
            is scanned and the best score is reported
        integration_interval : int, default 0
            if nonzero, the predictions are averaged over a sliding window of `2 * integration_interval + 1`
            frames before thresholding
        """

        if threshold_value is None:
            threshold_value = 0.5
        super().__init__(
            num_classes,
            ignore_index,
            average,
            ignored_classes,
            exclusive,
            tag_average=tag_average,
            threshold_values=[threshold_value],
            integration_interval=integration_interval,
        )

    def _calculate_metric(self, tp: Dict, fp: Dict, fn: Dict, tn: Dict) -> float:
        """
        Calculate the metric value from true and false positive and negative counts
        """

        if self.optimize:
            scores = []
            for k in self.threshold_values:
                recall = tp[k] / (tp[k] + fn[k] + 1e-7)
                precision = tp[k] / (tp[k] + fp[k] + 1e-7)
                scores.append(2 * recall * precision / (recall + precision + 1e-7))
            f1 = max(scores)
        else:
            k = self.threshold_values[0]
            if isinstance(k, list):
                k = ", ".join(map(str, k))
            recall = tp[k] / (tp[k] + fn[k] + 1e-7)
            precision = tp[k] / (tp[k] + fp[k] + 1e-7)
            f1 = 2 * recall * precision / (recall + precision + 1e-7)
        return f1

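# A usage sketch for the threshold-optimizing mode of `F1` (random toy data):
# with `threshold_value="optimize"` the base class scans decision thresholds
# 0.25, 0.30, ..., 0.70 and `_calculate_metric` reports the best score.
if __name__ == "__main__":
    metric = F1(num_classes=2, exclusive=False, threshold_value="optimize")
    metric.reset()
    predicted = torch.rand(4, 2, 100)  # (batch, class, frame) probabilities
    target = (torch.rand(4, 2, 100) > 0.5).int()
    metric.update(predicted, target, tags=None)
    print(metric.calculate())  # best macro F1 over the threshold grid
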
class SegmentalF1(_ClassificationMetric):
    """
    Segmental F1 score (not advised for training)
    """

    segmental = True

    def __init__(
        self,
        num_classes: int,
        ignore_index: int = -100,
        average: str = "macro",
        ignored_classes: Set = None,
        exclusive: bool = True,
        iou_threshold: float = 0.5,
        tag_average: str = "micro",
        threshold_value: Union[float, List] = None,
    ):
        """
        Parameters
        ----------
        num_classes : int
            number of classes in the dataset
        ignore_index : int, default -100
            the class index that indicates ignored samples
        average : {'macro', 'micro', 'none'}
            method for averaging across classes
        ignored_classes : set, optional
            a set of class ids to ignore in the calculation
        exclusive : bool, default True
            set to False for multi-label classification tasks
        iou_threshold : float, default 0.5
            intervals with IoU above this threshold are considered correct
        tag_average : {'micro', 'macro', 'none'}
            method for averaging across meta tags (if given)
        threshold_value : float | list, optional
            the decision threshold value (cannot be set for exclusive classification; 0.5 by default
            for non-exclusive); if `threshold_value` is a list, each value corresponds to the class
            with the same index; if set to `"optimize"`, a grid of thresholds is scanned and the
            best score is reported
        """

        if threshold_value is None:
            threshold_value = 0.5
        super().__init__(
            num_classes,
            ignore_index,
            average,
            ignored_classes,
            exclusive,
            iou_threshold,
            tag_average,
            [threshold_value],
        )

    def _calculate_metric(self, tp: Dict, fp: Dict, fn: Dict, tn: Dict) -> float:
        """
        Calculate the metric value from true and false positive and negative counts
        """

        if self.optimize:
            scores = []
            for k in self.threshold_values:
                recall = tp[k] / (tp[k] + fn[k] + 1e-7)
                precision = tp[k] / (tp[k] + fp[k] + 1e-7)
                scores.append(2 * recall * precision / (recall + precision + 1e-7))
            f1 = max(scores)
        else:
            k = self.threshold_values[0]
            if isinstance(k, list):
                k = ", ".join(map(str, k))
            recall = tp[k] / (tp[k] + fn[k] + 1e-7)
            precision = tp[k] / (tp[k] + fp[k] + 1e-7)
            f1 = 2 * recall * precision / (recall + precision + 1e-7)
        return f1

class Fbeta(_ClassificationMetric):
    """
    F-beta score
    """

    def __init__(
        self,
        beta: float = 1,
        ignore_index: int = -100,
        average: str = "macro",
        num_classes: int = None,
        ignored_classes: Set = None,
        tag_average: str = "micro",
        exclusive: bool = True,
        threshold_value: float = 0.5,
    ):
        """
        Parameters
        ----------
        beta : float, default 1
            the beta parameter (beta > 1 weights recall higher, beta < 1 weights precision higher)
        ignore_index : int, default -100
            the class index that indicates ignored samples
        average : {'macro', 'micro', 'none'}
            method for averaging across classes
        num_classes : int
            number of classes in the dataset
        ignored_classes : set, optional
            a set of class ids to ignore in the calculation
        exclusive : bool, default True
            set to False for multi-label classification tasks
        tag_average : {'micro', 'macro', 'none'}
            method for averaging across meta tags (if given)
        threshold_value : float, default 0.5
            the decision threshold value (cannot be set for exclusive classification)
        """

        if threshold_value is None:
            threshold_value = 0.5
        self.beta2 = beta**2
        super().__init__(
            num_classes,
            ignore_index,
            average,
            ignored_classes,
            exclusive,
            tag_average=tag_average,
            threshold_values=[threshold_value],
        )

    def _calculate_metric(self, tp: Dict, fp: Dict, fn: Dict, tn: Dict) -> float:
        """
        Calculate the metric value from true and false positive and negative counts
        """

        if self.optimize:
            scores = []
            for k in self.threshold_values:
                recall = tp[k] / (tp[k] + fn[k] + 1e-7)
                precision = tp[k] / (tp[k] + fp[k] + 1e-7)
                scores.append(
                    (
                        (1 + self.beta2)
                        * precision
                        * recall
                        / (self.beta2 * precision + recall + 1e-7)
                    )
                )
            fbeta = max(scores)
        else:
            k = self.threshold_values[0]
            if isinstance(k, list):
                k = ", ".join(map(str, k))
            recall = tp[k] / (tp[k] + fn[k] + 1e-7)
            precision = tp[k] / (tp[k] + fp[k] + 1e-7)
            fbeta = (
                (1 + self.beta2)
                * precision
                * recall
                / (self.beta2 * precision + recall + 1e-7)
            )
        return fbeta

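# An arithmetic check of the F-beta formula used above (values invented):
# beta > 1 weights recall more heavily, beta < 1 weights precision more.
if __name__ == "__main__":
    precision, recall = 0.8, 0.4
    for beta in (0.5, 1.0, 2.0):
        beta2 = beta**2
        fbeta = (1 + beta2) * precision * recall / (beta2 * precision + recall)
        print(beta, round(fbeta, 3))  # 0.5 -> 0.667, 1.0 -> 0.533, 2.0 -> 0.444
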
class SegmentalFbeta(_ClassificationMetric):
    """
    Segmental F-beta score (not advised for training)
    """

    segmental = True

    def __init__(
        self,
        beta: float = 1,
        ignore_index: int = -100,
        average: str = "macro",
        num_classes: int = None,
        ignored_classes: Set = None,
        iou_threshold: float = 0.5,
        tag_average: str = "micro",
        exclusive: bool = True,
        threshold_value: float = 0.5,
    ):
        """
        Parameters
        ----------
        beta : float, default 1
            the beta parameter (beta > 1 weights recall higher, beta < 1 weights precision higher)
        ignore_index : int, default -100
            the class index that indicates ignored samples
        average : {'macro', 'micro', 'none'}
            method for averaging across classes
        num_classes : int
            number of classes in the dataset
        ignored_classes : set, optional
            a set of class ids to ignore in the calculation
        exclusive : bool, default True
            set to False for multi-label classification tasks
        iou_threshold : float, default 0.5
            intervals with IoU above this threshold are considered correct
        tag_average : {'micro', 'macro', 'none'}
            method for averaging across meta tags (if given)
        threshold_value : float, default 0.5
            the decision threshold value (cannot be set for exclusive classification)
        """

        if threshold_value is None:
            threshold_value = 0.5
        self.beta2 = beta**2
        super().__init__(
            num_classes,
            ignore_index,
            average,
            ignored_classes,
            exclusive,
            iou_threshold=iou_threshold,
            tag_average=tag_average,
            threshold_values=[threshold_value],
        )

    def _calculate_metric(self, tp: Dict, fp: Dict, fn: Dict, tn: Dict) -> float:
        """
        Calculate the metric value from true and false positive and negative counts
        """

        if self.optimize:
            scores = []
            for k in self.threshold_values:
                recall = tp[k] / (tp[k] + fn[k] + 1e-7)
                precision = tp[k] / (tp[k] + fp[k] + 1e-7)
                scores.append(
                    (
                        (1 + self.beta2)
                        * precision
                        * recall
                        / (self.beta2 * precision + recall + 1e-7)
                    )
                )
            fbeta = max(scores)
        else:
            k = self.threshold_values[0]
            if isinstance(k, list):
                k = ", ".join(map(str, k))
            recall = tp[k] / (tp[k] + fn[k] + 1e-7)
            precision = tp[k] / (tp[k] + fp[k] + 1e-7)
            fbeta = (
                (1 + self.beta2)
                * precision
                * recall
                / (self.beta2 * precision + recall + 1e-7)
            )
        return fbeta

class _SemiSegmentalMetric(_ClassificationMetric):
    def __init__(
        self,
        num_classes: int,
        ignore_index: int = -100,
        ignored_classes: Set = None,
        exclusive: bool = True,
        average: str = "macro",
        tag_average: str = "micro",
        delta: int = 0,
        smooth_interval: int = 0,
        iou_threshold_long: float = 0.5,
        iou_threshold_short: float = 0.5,
        short_length: int = 30,
        long_length: int = 300,
        threshold_values: List = None,
    ) -> None:
        """
        Parameters
        ----------
        num_classes : int
            the number of classes in the dataset
        ignore_index : int, default -100
            the ground truth label to ignore
        ignored_classes : set, optional
            the class indices to ignore in the computation
        exclusive : bool, default True
            `False` for multi-label classification tasks
        average : {"macro", "micro", "none"}
            the method to average the results over classes
        tag_average : {"macro", "micro", "none"}
            the method to average the results over meta tags (if given)
        delta : int, default 0
            the number of frames to add to each ground truth interval before computing the intersection,
            see the description of the class for details
        smooth_interval : int, default 0
            intervals shorter than this number of frames will be ignored (both in prediction and in ground truth),
            see the description of the class for details
        iou_threshold_long : float, default 0.5
            the intersection threshold for segments longer than `long_length` frames (between 0 and 1),
            see the description of the class for details
        iou_threshold_short : float, default 0.5
            the intersection threshold for segments shorter than `short_length` frames (between 0 and 1),
            see the description of the class for details
        short_length : int, default 30
            the threshold number of frames for short intervals that will have an intersection threshold of
            `iou_threshold_short`, see the description of the class for details
        long_length : int, default 300
            the threshold number of frames for long intervals that will have an intersection threshold of
            `iou_threshold_long`, see the description of the class for details
        threshold_values : list, optional
            a list of decision threshold values (0.5 by default)
        """

        if smooth_interval > 0 and exclusive:
            warnings.warn(
                "Smoothing is not implemented for datasets with exclusive classification! Setting smooth_interval to 0..."
            )
            smooth_interval = 0
        if threshold_values == [None]:
            threshold_values = [0.5]
        super().__init__(
            num_classes,
            ignore_index,
            average,
            ignored_classes,
            exclusive,
            tag_average=tag_average,
            threshold_values=threshold_values,
        )
        self.delta = delta
        self.random_sampling = False
        self.smooth_interval = smooth_interval
        self.threshold_function = self._sigmoid_threshold_function(
            low_threshold=iou_threshold_short,
            high_threshold=iou_threshold_long,
            low_length=short_length,
            high_length=long_length,
        )

    def update(
        self,
        predicted: torch.Tensor,
        target: torch.Tensor,
        tags: torch.Tensor,
    ) -> None:
        """
        Update the intrinsic parameters (with a batch)

        Parameters
        ----------
        predicted : torch.Tensor
            the main prediction tensor generated by the model
        target : torch.Tensor
            the corresponding main target tensor
        tags : torch.Tensor
            the tensor of meta tags (or `None`, if tags are not given)
        """

        if self.exclusive:
            predicted = torch.max(predicted, 1)[1]
        predicted = torch.cat(
            [
                copy(predicted),
                -100
                * torch.ones((*predicted.shape[:-1], self.delta + 1)).to(
                    predicted.device
                ),
            ],
            dim=-1,
        )
        target = torch.cat(
            [
                copy(target),
                -100
                * torch.ones((*target.shape[:-1], self.delta + 1)).to(target.device),
            ],
            dim=-1,
        )
        if self.exclusive:
            predicted = predicted.flatten()
            target = target.flatten()
        else:
            num_classes = predicted.shape[1]
            predicted = predicted.transpose(1, 2).reshape(-1, num_classes)
            target = target.transpose(1, 2).reshape(-1, num_classes)
        if self.tag_average == "micro" or tags is None or tags[0] is None:
            tag_set = {None}
        else:
            tag_set = set(tags)
            self.tags.update(tag_set)
        if self.smooth_interval > 0:
            target = self._smooth(target, self.smooth_interval)
        for thr in self.threshold_values:
            key = thr
            for t in tag_set:
                if t is None:
                    predicted_t = deepcopy(predicted)
                    target_t = target
                else:
                    predicted_t = deepcopy(predicted[tags == t])
                    target_t = target[tags == t]
                if not self.exclusive:
                    if isinstance(thr, list):
                        # after the reshape above, classes are in the second dimension
                        for i, tt in enumerate(thr):
                            predicted_t[:, i] = (predicted_t[:, i] > tt).int()
                        key = ", ".join(map(str, thr))
                    else:
                        predicted_t = (predicted_t > thr).int()
                if self.smooth_interval > 0:
                    predicted_t = self._smooth(predicted_t, self.smooth_interval)
                for c in self.classes:
                    if not self.random_sampling:
                        if self.exclusive:
                            target_intervals = self._get_intervals(target_t == c)
                        else:
                            target_intervals = self._get_intervals(target_t[:, c] == 1)
                        target_lengths = [
                            end - start for start, end in target_intervals
                        ]
                        target_intervals = [
                            [
                                max(start - self.delta, 0),
                                min(end + self.delta, len(target_t)),
                            ]
                            for start, end in target_intervals
                        ]
                    else:
                        if self.exclusive:
                            target_arr = target_t == c
                        else:
                            target_arr = target_t[:, c] == 1
                        # sample every 20th positive frame as an anchor point
                        target_points = torch.where(target_arr)[0][::20]
                        target_intervals = []
                        for p in target_points:
                            target_intervals.append(
                                [
                                    max(0, p - self.delta),
                                    min(len(target_arr), p + self.delta),
                                ]
                            )
                        target_lengths = [
                            end - start for start, end in target_intervals
                        ]
                    if self.exclusive:
                        predicted_arr = predicted_t == c
                    else:
                        predicted_arr = predicted_t[:, c] == 1
                    for interval, l in zip(target_intervals, target_lengths):
                        intersection = torch.sum(
                            predicted_arr[interval[0] : interval[1]]
                        )
                        IoU = intersection / l
                        if IoU >= self.threshold_function(l):
                            self.tp[self._key(t, c)][key] += 1
                        else:
                            self.fn[self._key(t, c)][key] += 1

                    if not self.random_sampling:
                        if self.exclusive:
                            predicted_intervals = self._get_intervals(predicted_t == c)
                        else:
                            predicted_intervals = self._get_intervals(
                                predicted_t[:, c] == 1
                            )
                        predicted_intervals_delta = [
                            [
                                max(start - self.delta, 0),
                                min(end + self.delta, len(target_t)),
                            ]
                            for start, end in predicted_intervals
                        ]
                    else:
                        if self.exclusive:
                            predicted_arr = predicted_t == c
                        else:
                            predicted_arr = predicted_t[:, c] == 1
                        predicted_points = torch.where(predicted_arr)[0][::20]
                        predicted_intervals = []
                        for p in predicted_points:
                            predicted_intervals.append(
                                [
                                    max(0, p - self.delta),
                                    min(len(predicted_arr), p + self.delta),
                                ]
                            )
                        predicted_intervals_delta = predicted_intervals
                    if self.exclusive:
                        target_arr = (target_t == c).int()
                        # mark ignored frames so that they can be excluded below
                        target_arr[target_t == -100] = -100
                    else:
                        target_arr = target_t[:, c]
                    for interval, interval_delta in zip(
                        predicted_intervals, predicted_intervals_delta
                    ):
                        # skip intervals that are mostly ignored in the ground truth
                        if torch.sum(
                            target_arr[interval_delta[0] : interval_delta[1]] != -100
                        ) < 0.3 * (interval[1] - interval[0]):
                            continue
                        l = torch.sum(target_arr[interval[0] : interval[1]] != -100)
                        intersection = torch.sum(
                            target_arr[interval_delta[0] : interval_delta[1]] == 1
                        )
                        IoU = intersection / l
                        if IoU >= self.threshold_function(l):
                            self.tn[self._key(t, c)][key] += 1
                        else:
                            self.fp[self._key(t, c)][key] += 1


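# A numeric illustration of the sigmoid interval threshold used by the
# semi-segmental metrics (the threshold values below are invented; with the
# defaults the curve is flat at 0.5): short segments need roughly
# `iou_threshold_short` overlap, long segments roughly `iou_threshold_long`,
# with a smooth ramp in between.
if __name__ == "__main__":
    low_t, high_t, low_len, high_len = 0.4, 0.6, 30, 300
    a = 2 / (high_len - low_len)
    b = 1 - a * high_len
    for length in (30, 165, 300):
        threshold = low_t + torch.sigmoid(
            torch.tensor(4.0 * (a * length + b))
        ) * (high_t - low_t)
        print(length, round(float(threshold), 3))  # 30 -> 0.404, 165 -> 0.5, 300 -> 0.596
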
1356class SemiSegmentalRecall(_SemiSegmentalMetric):
1357    """
1358    Semi-segmental recall (not advised for training)
1359
1360    A metric in-between segmental and frame-wise recall.
1361
1362    The metric value is computed as follows:
1363    1) smooth over too-short intervals, both in ground truth and in prediction (first remove
1364        groups of zeros shorter than `smooth_interval` and then do the same with groups of ones),
1365    2) add `delta` frames to each ground truth interval at both ends and count the number of predicted
1366        positive frames in the resulting intervals (intersection),
1367    3) calculate the threshold for each interval as
1368        `t = iou_threshold_short + sigmoid(4 * (a * x + b)) * (iou_threshold_long - iou_threshold_short)`, where
1369        `a = 2 / (long_length - short_length)`, `b = 1 - a * long_length` and `x` is the length of the interval
1370        before `delta` was added,
1371    4) if the intersection is higher than `t * x`, the interval is labeled as true positive (`TP`),
1372        and otherwise as false negative (`FN`),
1373    5) the final metric value is computed as `TP / (TP + FN)`.
1374    """
1375
1376    def __init__(
1377        self,
1378        num_classes: int,
1379        ignore_index: int = -100,
1380        ignored_classes: Set = None,
1381        exclusive: bool = True,
1382        average: str = "macro",
1383        tag_average: str = "micro",
1384        delta: int = 0,
1385        smooth_interval: int = 0,
1386        iou_threshold_long: float = 0.5,
1387        iou_threshold_short: float = 0.5,
1388        short_length: int = 30,
1389        long_length: int = 300,
1390        threshold_value: Union[float, List] = None,
1391    ) -> None:
1392        """
1393        Parameters
1394        ----------
1395        num_classes : int
1396            the number of classes in the dataset
1397        ignore_index : int, default -100
1398            the ground truth label to ignore
1399        ignored_classes : set, optional
1400            the class indices to ignore in computation
1401        exclusive : bool, default True
1402            `False` for multi-label classification tasks
1403        average : {"macro", "micro", "none"}
1404            the method to average the results over classes
1405        tag_average : {"macro", "micro", "none"}
1406            the method to average the results over meta tags (if given)
1407        delta : int, default 0
1408            the number of frames to add to each ground truth interval before computing the intersection,
1409            see description of the class for details
1410        smooth_interval : int, default 0
1411            intervals shorter than this number of frames will be ignored (both in prediction and in
1412            ground truth); see description of the class for details
1413        iou_threshold_long : float, default 0.5
1414            the intersection threshold for segments longer than `long_length` frames (between 0 and 1),
1415            see description of the class for details
1416        iou_threshold_short : float, default 0.5
1417            the intersection threshold for segments shorter than `short_length` frames (between 0 and 1),
1418            see description of the class for details
1419        short_length : int, default 30
1420            the threshold number of frames for short intervals that will have an intersection threshold of
1421            `iou_threshold_short`, see description of the class for details
1422        long_length : int, default 300
1423            the threshold number of frames for long intervals that will have an intersection threshold of
1424            `iou_threshold_long`, see description of the class for details
1425        """
1426
1427        super().__init__(
1428            num_classes,
1429            ignore_index,
1430            ignored_classes,
1431            exclusive,
1432            average,
1433            tag_average,
1434            delta,
1435            smooth_interval,
1436            iou_threshold_long,
1437            iou_threshold_short,
1438            short_length,
1439            long_length,
1440            [threshold_value],
1441        )
1442
1443    def _calculate_metric(self, tp: Dict, fp: Dict, fn: Dict, tn: Dict) -> float:
1444        """
1445        Calculate the metric value from true and false positive and negative rates
1446        """
1447
1448        k = self.threshold_values[0]
1449        if isinstance(k, list):
1450            k = ", ".join(map(str, k))
1451        return tp[k] / (tp[k] + fn[k] + 1e-7)
1452
1453
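To make step 3 of the algorithm above concrete, here is a sketch of the length-dependent threshold as I read the docstring; the module's actual `threshold_function` may differ in detail, and `interpolated_threshold` is an illustrative name rather than the library's API:

    import numpy as np

    def interpolated_threshold(x, short_length=30, long_length=300,
                               iou_threshold_short=0.5, iou_threshold_long=0.5):
        a = 2 / (long_length - short_length)
        b = 1 - a * long_length
        s = 1 / (1 + np.exp(-4 * (a * x + b)))  # ~0.02 at short_length, ~0.98 at long_length
        return iou_threshold_short + s * (iou_threshold_long - iou_threshold_short)

With the default values the two endpoint thresholds coincide, so `t` is constant at 0.5; setting, say, `iou_threshold_short=0.3` and `iou_threshold_long=0.7` makes short intervals easier to match than long ones.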
1454class SemiSegmentalPrecision(_SemiSegmentalMetric):
1455    """
1456    Semi-segmental precision (not advised for training)
1457
1458    A metric in-between segmental and frame-wise precision.
1459
1460    The metric value is computed as follows:
1461    1) smooth over too-short intervals, both in ground truth and in prediction (first remove
1462        groups of zeros shorter than `smooth_interval` and then do the same with groups of ones),
1463    2) add `delta` frames to each predicted interval at both ends and count the number of ground truth
1464        positive frames in the resulting intervals (intersection),
1465    3) calculate the threshold for each interval as
1466        `t = iou_threshold_short + sigmoid(4 * (a * x + b)) * (iou_threshold_long - iou_threshold_short)`, where
1467        `a = 2 / (long_length - short_length)`, `b = 1 - a * long_length` and `x` is the length of the interval
1468        before `delta` was added,
1469    4) if the intersection is higher than `t * x`, the interval is labeled as true positive (`TP`),
1470        and otherwise as false positive (`FP`),
1471    5) the final metric value is computed as `TP / (TP + FP)`.
1472    """
1473
1474    def __init__(
1475        self,
1476        num_classes: int,
1477        ignore_index: int = -100,
1478        ignored_classes: Set = None,
1479        exclusive: bool = True,
1480        average: str = "macro",
1481        tag_average: str = "micro",
1482        delta: int = 0,
1483        smooth_interval: int = 0,
1484        iou_threshold_long: float = 0.5,
1485        iou_threshold_short: float = 0.5,
1486        short_length: int = 30,
1487        long_length: int = 300,
1488        threshold_value: Union[float, List] = None,
1489    ) -> None:
1490        """
1491        Parameters
1492        ----------
1493        num_classes : int
1494            the number of classes in the dataset
1495        ignore_index : int, default -100
1496            the ground truth label to ignore
1497        ignored_classes : set, optional
1498            the class indices to ignore in computation
1499        exclusive : bool, default True
1500            `False` for multi-label classification tasks
1501        average : {"macro", "micro", "none"}
1502            the method to average the results over classes
1503        tag_average : {"macro", "micro", "none"}
1504            the method to average the results over meta tags (if given)
1505        delta : int, default 0
1506            the number of frames to add to each predicted interval before computing the intersection,
1507            see description of the class for details
1508        smooth_interval : int, default 0
1509            intervals shorter than this number of frames will be ignored (both in prediction and in
1510            ground truth); see description of the class for details
1511        iou_threshold_long : float, default 0.5
1512            the intersection threshold for segments longer than `long_length` frames (between 0 and 1),
1513            see description of the class for details
1514        iou_threshold_short : float, default 0.5
1515            the intersection threshold for segments shorter than `short_length` frames (between 0 and 1),
1516            see description of the class for details
1517        short_length : int, default 30
1518            the threshold number of frames for short intervals that will have an intersection threshold of
1519            `iou_threshold_short`, see description of the class for details
1520        long_length : int, default 300
1521            the threshold number of frames for long intervals that will have an intersection threshold of
1522            `iou_threshold_long`, see description of the class for details
1523        """
1524
1525        super().__init__(
1526            num_classes,
1527            ignore_index,
1528            ignored_classes,
1529            exclusive,
1530            average,
1531            tag_average,
1532            delta,
1533            smooth_interval,
1534            iou_threshold_long,
1535            iou_threshold_short,
1536            short_length,
1537            long_length,
1538            [threshold_value],
1539        )
1540
1541    def _calculate_metric(self, tp: Dict, fp: Dict, fn: Dict, tn: Dict) -> float:
1542        """
1543        Calculate the metric value from true and false positive and negative rates
1544        """
1545
1546        k = self.threshold_values[0]
1547        if isinstance(k, list):
1548            k = ", ".join(map(str, k))
1549        return tn[k] / (tn[k] + fp[k] + 1e-7)  # prediction-side true positives are stored in the tn counter
1550
1551
1552class SemiSegmentalF1(_SemiSegmentalMetric):
1553    """
1554    The F1 score for semi-segmental recall and precision (not advised for training)
1555    """
1556
1557    def __init__(
1558        self,
1559        num_classes: int,
1560        ignore_index: int = -100,
1561        ignored_classes: Set = None,
1562        exclusive: bool = True,
1563        average: str = "macro",
1564        tag_average: str = "micro",
1565        delta: int = 0,
1566        smooth_interval: int = 0,
1567        iou_threshold_long: float = 0.5,
1568        iou_threshold_short: float = 0.5,
1569        short_length: int = 30,
1570        long_length: int = 300,
1571        threshold_value: Union[float, List] = None,
1572    ) -> None:
1573        """
1574        Parameters
1575        ----------
1576        num_classes : int
1577            the number of classes in the dataset
1578        ignore_index : int, default -100
1579            the ground truth label to ignore
1580        ignored_classes : set, optional
1581            the class indices to ignore in computation
1582        exclusive : bool, default True
1583            `False` for multi-label classification tasks
1584        average : {"macro", "micro", "none"}
1585            the method to average the results over classes
1586        tag_average : {"macro", "micro", "none"}
1587            the method to average the results over meta tags (if given)
1588        delta : int, default 0
1589            the number of frames to add to each ground truth interval before computing the intersection,
1590            see description of the class for details
1591        smooth_interval : int, default 0
1592            intervals shorter than this number of frames will be ignored (both in prediction and in
1593            ground truth); see description of the class for details
1594        iou_threshold_long : float, default 0.5
1595            the intersection threshold for segments longer than `long_length` frames (between 0 and 1),
1596            see description of the class for details
1597        iou_threshold_short : float, default 0.5
1598            the intersection threshold for segments shorter than `short_length` frames (between 0 and 1),
1599            see description of the class for details
1600        short_length : int, default 30
1601            the threshold number of frames for short intervals that will have an intersection threshold of
1602            `iou_threshold_short`, see description of the class for details
1603        long_length : int, default 300
1604            the threshold number of frames for long intervals that will have an intersection threshold of
1605            `iou_threshold_long`, see description of the class for details
1606        """
1607
1608        super().__init__(
1609            num_classes,
1610            ignore_index,
1611            ignored_classes,
1612            exclusive,
1613            average,
1614            tag_average,
1615            delta,
1616            smooth_interval,
1617            iou_threshold_long,
1618            iou_threshold_short,
1619            short_length,
1620            long_length,
1621            [threshold_value],
1622        )
1623
1624    def _calculate_metric(self, tp: Dict, fp: Dict, fn: Dict, tn: Dict) -> float:
1625        """
1626        Calculate the metric value from true and false positive and negative rates
1627        """
1628
1629        if self.optimize:
1630            scores = []
1631            for k in self.threshold_values:
1632                recall = tp[k] / (tp[k] + fn[k] + 1e-7)
1633                precision = tn[k] / (tn[k] + fp[k] + 1e-7)
1634                scores.append(2 * recall * precision / (recall + precision + 1e-7))
1635            f1 = max(scores)
1636        else:
1637            k = self.threshold_values[0]
1638            if isinstance(k, list):
1639                k = ", ".join(map(str, k))
1640            recall = tp[k] / (tp[k] + fn[k] + 1e-7)
1641            precision = tn[k] / (tn[k] + fp[k] + 1e-7)
1642            f1 = 2 * recall * precision / (recall + precision + 1e-7)
1643        return f1
1644
1645
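As a quick worked example of the harmonic mean computed above (ignoring the `1e-7` stabilizer): with recall 0.8 and precision 0.6,

    recall, precision = 0.8, 0.6  # toy numbers, not from any dataset
    f1 = 2 * recall * precision / (recall + precision)  # 0.96 / 1.4 ≈ 0.686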
1646class SemiSegmentalPR_AUC(_SemiSegmentalMetric):
1647    """
1648    The area under the precision-recall curve for semi-segmental metrics (not advised for training)
1649    """
1650
1651    def __init__(
1652        self,
1653        num_classes: int,
1654        ignore_index: int = -100,
1655        ignored_classes: Set = None,
1656        exclusive: bool = True,
1657        average: str = "macro",
1658        tag_average: str = "micro",
1659        delta: int = 0,
1660        smooth_interval: int = 0,
1661        iou_threshold_long: float = 0.5,
1662        iou_threshold_short: float = 0.5,
1663        short_length: int = 30,
1664        long_length: int = 300,
1665        threshold_step: float = 0.1,
1666    ) -> None:
1667        super().__init__(
1668            num_classes,
1669            ignore_index,
1670            ignored_classes,
1671            exclusive,
1672            average,
1673            tag_average,
1674            delta,
1675            smooth_interval,
1676            iou_threshold_long,
1677            iou_threshold_short,
1678            short_length,
1679            long_length,
1680            list(np.arange(0, 1, threshold_step)),
1681        )
1682
1683    def _calculate_metric(self, tp: Dict, fp: Dict, fn: Dict, tn: Dict) -> float:
1684        """
1685        Calculate the metric value from true and false positive and negative rates
1686        """
1687
1688        precisions = []
1689        recalls = []
1690        for k in sorted(self.threshold_values):
1691            precisions.append(tn[k] / (tn[k] + fp[k] + 1e-7))
1692            recalls.append(tp[k] / (tp[k] + fn[k] + 1e-7))
1693        return metrics.auc(x=recalls, y=precisions)
1694
1695
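The PR-AUC metrics trace one (recall, precision) point per decision threshold and integrate the curve with scikit-learn's trapezoidal `metrics.auc`. A minimal illustration with invented counts (the dictionaries below are hypothetical stand-ins for the internal per-threshold counters):

    from sklearn import metrics

    tp = {0.25: 90, 0.5: 70, 0.75: 40}  # invented per-threshold counts
    fp = {0.25: 30, 0.5: 10, 0.75: 2}
    fn = {0.25: 10, 0.5: 30, 0.75: 60}
    eps = 1e-7
    recalls = [tp[k] / (tp[k] + fn[k] + eps) for k in sorted(tp)]
    precisions = [tp[k] / (tp[k] + fp[k] + eps) for k in sorted(tp)]
    auc = metrics.auc(x=recalls, y=precisions)  # recalls are monotone, as auc requires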
1696class Accuracy(Metric):
1697    """
1698    Accuracy
1699    """
1700
1701    def __init__(self, ignore_index=-100):
1702        """
1703        Parameters
1704        ----------
1705        ignore_index : int, default -100
1706            the class index that indicates ignored samples
1707        """
1708
1709        super().__init__()
1710        self.ignore_index = ignore_index
1711
1712    def reset(self) -> None:
1713        """
1714        Reset the intrinsic parameters (at the beginning of an epoch)
1715        """
1716
1717        self.total = 0
1718        self.correct = 0
1719
1720    def calculate(self) -> float:
1721        """
1722        Calculate the metric value
1723
1724        Returns
1725        -------
1726        metric : float
1727            metric value
1728        """
1729
1730        return self.correct / (self.total + 1e-7)
1731
1732    def update(
1733        self,
1734        predicted: torch.Tensor,
1735        target: torch.Tensor,
1736        tags: torch.Tensor = None,
1737    ) -> None:
1738        """
1739        Update the intrinsic parameters (with a batch)
1740
1741        Parameters
1742        ----------
1743        predicted : torch.Tensor
1744            the main prediction tensor generated by the model
1747        target : torch.Tensor
1748            the corresponding main target tensor
1751        tags : torch.Tensor
1752            the tensor of meta tags (or `None`, if tags are not given)
1753        """
1754
1755        mask = target != self.ignore_index
1756        self.total += torch.sum(mask)
1757        self.correct += torch.sum((target == predicted)[mask])
1758
1759
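A toy check of the masking logic in `Accuracy.update` above: frames labeled with `ignore_index` contribute neither to the numerator nor to the denominator.

    import torch

    predicted = torch.tensor([0, 1, 1, 2])
    target = torch.tensor([0, 1, -100, 1])
    mask = target != -100
    correct = torch.sum((target == predicted)[mask])  # 2
    total = torch.sum(mask)                           # 3 -> accuracy = 2/3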
1760class Count(Metric):
1761    """
1762    Fraction of samples labeled by the model as a class
1763    """
1764
1765    def __init__(self, classes: Set, exclusive: bool = True):
1766        """
1767        Parameters
1768        ----------
1769        classes : set
1770            the set of classes to count
1771        exclusive: bool, default True
1772            set to False for multi-label classification tasks
1773        """
1774
1775        super().__init__()
1776        self.classes = classes
1777        self.exclusive = exclusive
1778
1779    def reset(self) -> None:
1780        """
1781        Reset the intrinsic parameters (at the beginning of an epoch)
1782        """
1783
1784        self.count = defaultdict(lambda: 0)
1785        self.total = 0
1786
1787    def update(
1788        self,
1789        predicted: torch.Tensor,
1790        target: torch.Tensor,
1791        tags: torch.Tensor,
1792    ) -> None:
1793        """
1794        Update the intrinsic parameters (with a batch)
1795
1796        Parameters
1797        ----------
1798        predicted : torch.Tensor
1799            the main prediction tensor generated by the model
1802        target : torch.Tensor
1803            the corresponding main target tensor
1806        tags : torch.Tensor
1807            the tensor of meta tags (or `None`, if tags are not given)
1808        """
1809
1810        if self.exclusive:
1811            for c in self.classes:
1812                self.count[c] += torch.sum(predicted == c)
1813            self.total += torch.numel(predicted)
1814        else:
1815            for c in self.classes:
1816                self.count[c] += torch.sum(predicted[:, c, :] == 1)
1817            self.total += torch.numel(predicted[:, 0, :])
1818
1819    def calculate(self) -> Dict:
1820        """
1821        Calculate the metric (at the end of an epoch)
1822
1823        Returns
1824        -------
1825        result : dict
1826            a dictionary where the keys are class indices and the values are class metric values
1827        """
1828
1829        # build fresh values instead of mutating the running counters, so that
1830        # calling `calculate` more than once per epoch stays consistent
1831        return {c: self.count[c] / (self.total + 1e-7) for c in self.classes}
1832
1833
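For the exclusive branch of `Count` above, the reported value per class is simply the fraction of frames the model assigned to that class:

    import torch

    predicted = torch.tensor([1, 0, 1, 2])
    frac_class_1 = torch.sum(predicted == 1).item() / torch.numel(predicted)  # 2 / 4 = 0.5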
1834class EditDistance(Metric):
1835    """
1836    Edit distance (not advised for training)
1837
1838    Normalized by the length of the sequences
1839    """
1840
1841    def __init__(self, ignore_index: int = -100, exclusive: bool = True) -> None:
1842        """
1843        Parameters
1844        ----------
1845        ignore_index : int, default -100
1846            the class index that indicates samples that should be ignored
1847        exclusive : bool, default True
1848            set to False for multi-label classification tasks
1849        """
1850
1851        super().__init__()
1852        self.ignore_index = ignore_index
1853        self.exclusive = exclusive
1854
1855    def reset(self) -> None:
1856        """
1857        Reset the intrinsic parameters (at the beginning of an epoch)
1858        """
1859
1860        self.edit_distance = 0
1861        self.total = 0
1862
1863    def update(
1864        self,
1865        predicted: torch.Tensor,
1866        target: torch.Tensor,
1867        tags: torch.Tensor,
1868    ) -> None:
1869        """
1870        Update the intrinsic parameters (with a batch)
1871
1872        Parameters
1873        ----------
1874        predicted : torch.Tensor
1875            the main prediction tensor generated by the model
1878        target : torch.Tensor
1879            the corresponding main target tensor
1882        tags : torch.Tensor
1883            the tensor of meta tags (or `None`, if tags are not given)
1884        """
1885
1886        mask = target != self.ignore_index
1887        self.total += torch.sum(mask)
1888        if self.exclusive:
1889            predicted = predicted[mask].flatten()
1890            target = target[mask].flatten()
1891            self.edit_distance += editdistance.eval(
1892                predicted.detach().cpu().numpy(), target.detach().cpu().numpy()
1893            )
1894        else:
1895            for c in range(target.shape[1]):
1896                predicted_class = predicted[:, c, :][mask[:, c, :]].flatten()
1897                target_class = target[:, c, :][mask[:, c, :]].flatten()
1898                self.edit_distance += editdistance.eval(
1899                    predicted_class.detach().cpu().tolist(),
1900                    target_class.detach().cpu().tolist(),
1901                )
1902
1903    def _is_equal(self, a, b):
1904        """
1905        Compare while ignoring samples marked with ignore_index
1906        """
1907
1908        return self.ignore_index in (a, b) or a == b
1912
1913    def calculate(self) -> float:
1914        """
1915        Calculate the metric (at the end of an epoch)
1916
1917        Returns
1918        -------
1919        result : float
1920            the metric value
1921        """
1922
1923        return self.edit_distance / (self.total + 1e-7)
1924
1925
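A small example of the normalization used by `EditDistance` above: the raw edit distance between the label sequences is divided by the number of counted frames.

    import editdistance

    pred = [0, 0, 1, 1, 2]
    gold = [0, 1, 1, 1, 2]
    editdistance.eval(pred, gold) / len(gold)  # one substitution over 5 frames = 0.2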
1926class PKU_mAP(Metric):
1927    """
1928    Mean average precision (segmental) (not advised for training)
1929    """
1930
1931    needs_raw_data = True
1932
1933    def __init__(
1934        self,
1935        average,
1936        exclusive,
1937        num_classes,
1938        iou_threshold=0.5,
1939        threshold_value=0.5,
1940        ignored_classes=None,
1941    ):
1942        if ignored_classes is None:
1943            ignored_classes = []
1944        self.average = average
1945        self.iou_threshold = iou_threshold
1946        self.threshold = threshold_value
1947        self.exclusive = exclusive
1948        self.classes = [x for x in range(num_classes) if x not in ignored_classes]
1949        super().__init__()
1950
1951    def match(self, lst, ratio, ground):
1952        lst = sorted(lst, key=lambda x: x[2])
1953
1954        def overlap(prop, ground):
1955            s_p, e_p, _ = prop
1956            s_g, e_g, _ = ground
1957            return (min(e_p, e_g) - max(s_p, s_g)) / (max(e_p, e_g) - min(s_p, s_g))  # 1-D IoU (hull == union for overlapping intervals)
1958
1959        cos_map = [-1 for x in range(len(lst))]
1960        count_map = [0 for x in range(len(ground))]
1961
1962        for x in range(len(lst)):
1963            for y in range(len(ground)):
1964                if overlap(lst[x], ground[y]) < ratio:
1965                    continue
1966                if cos_map[x] != -1 and overlap(lst[x], ground[y]) < overlap(lst[x], ground[cos_map[x]]):
1967                    continue  # keep the best admissible match; without the -1 guard this would index ground[-1]
1968                cos_map[x] = y
1969            if cos_map[x] != -1:
1970                count_map[cos_map[x]] += 1
1971        positive = sum([(x > 0) for x in count_map])
1972        return cos_map, count_map, positive, [x[2] for x in lst]
1973
1974    def reset(self) -> None:
1975        self.count_map = defaultdict(lambda: [])
1976        self.positive = defaultdict(lambda: 0)
1977        self.cos_map = defaultdict(lambda: [])
1978        self.confidence = defaultdict(lambda: [])
1979
1980    def calc_pr(self, positive, proposal, ground):
1981        if proposal == 0:
1982            return 0, 0
1983        if ground == 0:
1984            return 0, 0
1985        return (1.0 * positive) / proposal, (1.0 * positive) / ground
1986
1987    def calculate(self) -> Union[float, Dict]:
1988        if self.average == "micro":
1989            confidence = []
1990            count_map = []
1991            cos_map = []
1992            positive = sum(self.positive.values())
1993            for key in self.count_map.keys():
1994                confidence += self.confidence[key]
1995                cos_map += [x if x == -1 else x + len(count_map) for x in self.cos_map[key]]  # keep -1 (unmatched) markers intact
1996                count_map += self.count_map[key]
1997            return self.ap(cos_map, count_map, positive, confidence)
1998        results = {
1999            key: self.ap(
2000                self.cos_map[key],
2001                self.count_map[key],
2002                self.positive[key],
2003                self.confidence[key],
2004            )
2005            for key in self.count_map.keys()
2006        }
2007        if self.average == "none":
2008            return results
2009        else:
2010            return float(np.mean(list(results.values())))
2011
2012    def ap(self, cos_map, count_map, positive, confidence):
2013        indices = np.argsort(confidence)
2014        cos_map = list(np.array(cos_map)[indices])
2015        score = 0
2016        number_proposal = len(cos_map)
2017        number_ground = len(count_map)
2018        old_precision, old_recall = self.calc_pr(
2019            positive, number_proposal, number_ground
2020        )
2021
2022        for x in range(len(cos_map)):
2023            number_proposal -= 1
2024            if cos_map[x] == -1:
2025                continue
2026            count_map[cos_map[x]] -= 1
2027            if count_map[cos_map[x]] == 0:
2028                positive -= 1
2029
2030            precision, recall = self.calc_pr(positive, number_proposal, number_ground)
2031            if precision > old_precision:
2032                old_precision = precision
2033            score += old_precision * (old_recall - recall)
2034            old_recall = recall
2035        return score
2036
2037    def _get_intervals(
2038        self, tensor: torch.Tensor, probability: torch.Tensor = None
2039    ) -> Union[Tuple, torch.Tensor]:
2040        """
2041        Get True group beginning and end indices from a boolean tensor and average probability over these intervals
2042        """
2043
2044        output, indices = torch.unique_consecutive(tensor, return_inverse=True)
2045        true_indices = torch.where(output)[0]
2046        starts = torch.tensor(
2047            [(indices == i).nonzero(as_tuple=True)[0][0] for i in true_indices]
2048        )
2049        ends = torch.tensor(
2050            [(indices == i).nonzero(as_tuple=True)[0][-1] + 1 for i in true_indices]
2051        )
2052        confidence = torch.tensor(
2053            [probability[indices == i].mean() for i in true_indices]
2054        )
2055        return torch.stack([starts, ends, confidence]).T
2056
2057    def update(
2058        self,
2059        predicted: torch.Tensor,
2060        target: torch.Tensor,
2061        tags: torch.Tensor,
2062    ) -> None:
2063        predicted = torch.cat(
2064            [
2065                copy(predicted),
2066                -100 * torch.ones((*predicted.shape[:-1], 1)).to(predicted.device),
2067            ],
2068            dim=-1,
2069        )
2070        target = torch.cat(
2071            [
2072                copy(target),
2073                -100 * torch.ones((*target.shape[:-1], 1)).to(target.device),
2074            ],
2075            dim=-1,
2076        )
2077        num_classes = predicted.shape[1]
2078        predicted = predicted.transpose(1, 2).reshape(-1, num_classes)
2079        if self.exclusive:
2080            target = target.flatten()
2081        else:
2082            target = target.transpose(1, 2).reshape(-1, num_classes)
2083        probability = copy(predicted)
2084        if not self.exclusive:
2085            predicted = (predicted > self.threshold).int()
2086        else:
2087            predicted = torch.max(predicted, 1)[1]
2088        for c in self.classes:
2089            if self.exclusive:
2090                predicted_intervals = self._get_intervals(
2091                    predicted == c, probability=probability[:, c]
2092                )
2093                target_intervals = self._get_intervals(
2094                    target == c, probability=probability[:, c]
2095                )
2096            else:
2097                predicted_intervals = self._get_intervals(
2098                    predicted[:, c] == 1, probability=probability[:, c]
2099                )
2100                target_intervals = self._get_intervals(
2101                    target[:, c] == 1, probability=probability[:, c]
2102                )
2103            cos_map, count_map, positive, confidence = self.match(
2104                predicted_intervals, self.iou_threshold, target_intervals
2105            )
2106            cos_map = np.array(cos_map)
2107            cos_map[cos_map != -1] += len(self.count_map[c])
2108            self.cos_map[c] += list(cos_map)
2109            self.count_map[c] += count_map
2110            self.confidence[c] += confidence
2111            self.positive[c] += positive
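A worked example of the 1-D IoU used in `PKU_mAP.match`: for overlapping intervals the spanning hull equals the union, so the `overlap` helper is a true intersection-over-union. A proposal [10, 40) against a ground truth segment [20, 60) gives intersection 20 and union 50, i.e. IoU 0.4, which would not be matched at the default `iou_threshold=0.5`.

    def iou_1d(a, b):
        inter = min(a[1], b[1]) - max(a[0], b[0])
        union = max(a[1], b[1]) - min(a[0], b[0])  # hull == union when the intervals overlap
        return inter / union

    iou_1d((10, 40), (20, 60))  # 0.4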
class PR_AUC(_ClassificationMetric):
470class PR_AUC(_ClassificationMetric):
471    """
472    Area under precision-recall curve (not advised for training)
473    """
474
475    def __init__(
476        self,
477        num_classes: int,
478        ignore_index: int = -100,
479        average: str = "macro",
480        ignored_classes: Set = None,
481        exclusive: bool = False,
482        tag_average: str = "micro",
483        threshold_step: float = 0.1,
484    ):
485        """
486        Initialize the metric
487
488        Parameters
489        ----------
490        ignore_index : int, default -100
491            the class index that indicates ignored samples
492        average: {'macro', 'micro', 'none'}
493            method for averaging across classes
494        num_classes : int, optional
495            number of classes (not necessary if main_class is not None)
496        ignored_classes : set, optional
497            a set of class ids to ignore in calculation
498        exclusive: bool, default True
499            set to False for multi-label classification tasks
500        tag_average: {'micro', 'macro', 'none'}
501            method for averaging across meta tags (if given)
502        threshold_step : float, default 0.1
503            the decision threshold step
504        """
505
506        if exclusive:
507            raise ValueError(
508                "The PR-AUC metric is not implemented for exclusive classification!"
509            )
510        super().__init__(
511            num_classes,
512            ignore_index,
513            average,
514            ignored_classes,
515            exclusive,
516            tag_average=tag_average,
517            threshold_values=list(np.arange(0, 1, threshold_step)),
518        )
519
520    def _calculate_metric(self, tp: Dict, fp: Dict, fn: Dict, tn: Dict) -> float:
521        """
522        Calculate the metric value from true and false positive and negative rates
523        """
524
525        precisions = []
526        recalls = []
527        for k in sorted(self.threshold_values):
528            precisions.append(tp[k] / (tp[k] + fp[k] + 1e-7))
529            recalls.append(tp[k] / (tp[k] + fn[k] + 1e-7))
530        return metrics.auc(x=recalls, y=precisions)

Area under precision-recall curve (not advised for training)

PR_AUC( num_classes: int, ignore_index: int = -100, average: str = 'macro', ignored_classes: Set = None, exclusive: bool = False, tag_average: str = 'micro', threshold_step: float = 0.1)
475    def __init__(
476        self,
477        num_classes: int,
478        ignore_index: int = -100,
479        average: str = "macro",
480        ignored_classes: Set = None,
481        exclusive: bool = False,
482        tag_average: str = "micro",
483        threshold_step: float = 0.1,
484    ):
485        """
486        Initialize the metric
487
488        Parameters
489        ----------
490        ignore_index : int, default -100
491            the class index that indicates ignored samples
492        average: {'macro', 'micro', 'none'}
493            method for averaging across classes
494        num_classes : int, optional
495            number of classes (not necessary if main_class is not None)
496        ignored_classes : set, optional
497            a set of class ids to ignore in calculation
498        exclusive: bool, default True
499            set to False for multi-label classification tasks
500        tag_average: {'micro', 'macro', 'none'}
501            method for averaging across meta tags (if given)
502        threshold_step : float, default 0.1
503            the decision threshold step
504        """
505
506        if exclusive:
507            raise ValueError(
508                "The PR-AUC metric is not implemented for exclusive classification!"
509            )
510        super().__init__(
511            num_classes,
512            ignore_index,
513            average,
514            ignored_classes,
515            exclusive,
516            tag_average=tag_average,
517            threshold_values=list(np.arange(0, 1, threshold_step)),
518        )

Initialize the metric

Parameters

ignore_index : int, default -100 the class index that indicates ignored samples average: {'macro', 'micro', 'none'} method for averaging across classes num_classes : int, optional number of classes (not necessary if main_class is not None) ignored_classes : set, optional a set of class ids to ignore in calculation exclusive: bool, default True set to False for multi-label classification tasks tag_average: {'micro', 'macro', 'none'} method for averaging across meta tags (if given) threshold_step : float, default 0.1 the decision threshold step

class Precision(_ClassificationMetric):
533class Precision(_ClassificationMetric):
534    """
535    Precision
536    """
537
538    def __init__(
539        self,
540        num_classes: int,
541        ignore_index: int = -100,
542        average: str = "macro",
543        ignored_classes: Set = None,
544        exclusive: bool = True,
545        tag_average: str = "micro",
546        threshold_value: Union[float, List] = None,
547    ):
548        """
549        Parameters
550        ----------
551        ignore_index : int, default -100
552            the class index that indicates ignored samples
553        average: {'macro', 'micro', 'none'}
554            method for averaging across classes
555        num_classes : int, optional
556            number of classes (not necessary if main_class is not None)
557        ignored_classes : set, optional
558            a set of class ids to ignore in calculation
559        exclusive: bool, default True
560            set to False for multi-label classification tasks
561        tag_average: {'micro', 'macro', 'none'}
562            method for averaging across meta tags (if given)
563        threshold_value : float | list, optional
564            the decision threshold value (cannot be defined for exclusive classification, and 0.5 be default
565            for non-exclusive); if `threshold_value` is a list, every value should correspond to the class
566            under the same index
567        """
568
569        if threshold_value is None:
570            threshold_value = 0.5
571        super().__init__(
572            num_classes,
573            ignore_index,
574            average,
575            ignored_classes,
576            exclusive,
577            tag_average=tag_average,
578            threshold_values=[threshold_value],
579        )
580
581    def _calculate_metric(self, tp: Dict, fp: Dict, fn: Dict, tn: Dict) -> float:
582        """
583        Calculate the metric value from true and false positive and negative rates
584        """
585
586        k = self.threshold_values[0]
587        if isinstance(k, list):
588            k = ", ".join(map(str, k))
589        return tp[k] / (tp[k] + fp[k] + 1e-7)

Precision

Precision( num_classes: int, ignore_index: int = -100, average: str = 'macro', ignored_classes: Set = None, exclusive: bool = True, tag_average: str = 'micro', threshold_value: Union[float, List] = None)
538    def __init__(
539        self,
540        num_classes: int,
541        ignore_index: int = -100,
542        average: str = "macro",
543        ignored_classes: Set = None,
544        exclusive: bool = True,
545        tag_average: str = "micro",
546        threshold_value: Union[float, List] = None,
547    ):
548        """
549        Parameters
550        ----------
551        ignore_index : int, default -100
552            the class index that indicates ignored samples
553        average: {'macro', 'micro', 'none'}
554            method for averaging across classes
555        num_classes : int, optional
556            number of classes (not necessary if main_class is not None)
557        ignored_classes : set, optional
558            a set of class ids to ignore in calculation
559        exclusive: bool, default True
560            set to False for multi-label classification tasks
561        tag_average: {'micro', 'macro', 'none'}
562            method for averaging across meta tags (if given)
563        threshold_value : float | list, optional
564            the decision threshold value (cannot be defined for exclusive classification, and 0.5 be default
565            for non-exclusive); if `threshold_value` is a list, every value should correspond to the class
566            under the same index
567        """
568
569        if threshold_value is None:
570            threshold_value = 0.5
571        super().__init__(
572            num_classes,
573            ignore_index,
574            average,
575            ignored_classes,
576            exclusive,
577            tag_average=tag_average,
578            threshold_values=[threshold_value],
579        )

Parameters

ignore_index : int, default -100 the class index that indicates ignored samples average: {'macro', 'micro', 'none'} method for averaging across classes num_classes : int, optional number of classes (not necessary if main_class is not None) ignored_classes : set, optional a set of class ids to ignore in calculation exclusive: bool, default True set to False for multi-label classification tasks tag_average: {'micro', 'macro', 'none'} method for averaging across meta tags (if given) threshold_value : float | list, optional the decision threshold value (cannot be defined for exclusive classification, and 0.5 be default for non-exclusive); if threshold_value is a list, every value should correspond to the class under the same index

class SegmentalPrecision(_ClassificationMetric):
592class SegmentalPrecision(_ClassificationMetric):
593    """
594    Segmental precision (not advised for training)
595    """
596
597    segmental = True
598
599    def __init__(
600        self,
601        num_classes: int,
602        ignore_index: int = -100,
603        average: str = "macro",
604        ignored_classes: Set = None,
605        exclusive: bool = True,
606        iou_threshold: float = 0.5,
607        tag_average: str = "micro",
608        threshold_value: Union[float, List] = None,
609    ):
610        """
611        Parameters
612        ----------
613        ignore_index : int, default -100
614            the class index that indicates ignored samples
615        average: {'macro', 'micro', 'none'}
616            method for averaging across classes
617        num_classes : int, optional
618            number of classes (not necessary if main_class is not None)
619        ignored_classes : set, optional
620            a set of class ids to ignore in calculation
621        exclusive: bool, default True
622            set to False for multi-label classification tasks
623        iou_threshold : float, default 0.5
624            if segmental is true, intervals with IoU larger than this threshold are considered correct
625        tag_average: {'micro', 'macro', 'none'}
626            method for averaging across meta tags (if given)
627        threshold_value : float | list, optional
628            the decision threshold value (cannot be defined for exclusive classification, and 0.5 be default
629            for non-exclusive); if `threshold_value` is a list, every value should correspond to the class
630            under the same index
631        """
632
633        if threshold_value is None:
634            threshold_value = 0.5
635        super().__init__(
636            num_classes,
637            ignore_index,
638            average,
639            ignored_classes,
640            exclusive,
641            iou_threshold,
642            tag_average,
643            [threshold_value],
644        )
645
646    def _calculate_metric(self, tp: Dict, fp: Dict, fn: Dict, tn: Dict) -> float:
647        """
648        Calculate the metric value from true and false positive and negative rates
649        """
650
651        k = self.threshold_values[0]
652        if isinstance(k, list):
653            k = ", ".join(map(str, k))
654        return tp[k] / (tp[k] + fp[k] + 1e-7)

Segmental precision (not advised for training)

SegmentalPrecision( num_classes: int, ignore_index: int = -100, average: str = 'macro', ignored_classes: Set = None, exclusive: bool = True, iou_threshold: float = 0.5, tag_average: str = 'micro', threshold_value: Union[float, List] = None)
599    def __init__(
600        self,
601        num_classes: int,
602        ignore_index: int = -100,
603        average: str = "macro",
604        ignored_classes: Set = None,
605        exclusive: bool = True,
606        iou_threshold: float = 0.5,
607        tag_average: str = "micro",
608        threshold_value: Union[float, List] = None,
609    ):
610        """
611        Parameters
612        ----------
613        ignore_index : int, default -100
614            the class index that indicates ignored samples
615        average: {'macro', 'micro', 'none'}
616            method for averaging across classes
617        num_classes : int, optional
618            number of classes (not necessary if main_class is not None)
619        ignored_classes : set, optional
620            a set of class ids to ignore in calculation
621        exclusive: bool, default True
622            set to False for multi-label classification tasks
623        iou_threshold : float, default 0.5
624            if segmental is true, intervals with IoU larger than this threshold are considered correct
625        tag_average: {'micro', 'macro', 'none'}
626            method for averaging across meta tags (if given)
627        threshold_value : float | list, optional
628            the decision threshold value (cannot be defined for exclusive classification, and 0.5 be default
629            for non-exclusive); if `threshold_value` is a list, every value should correspond to the class
630            under the same index
631        """
632
633        if threshold_value is None:
634            threshold_value = 0.5
635        super().__init__(
636            num_classes,
637            ignore_index,
638            average,
639            ignored_classes,
640            exclusive,
641            iou_threshold,
642            tag_average,
643            [threshold_value],
644        )

Parameters

ignore_index : int, default -100 the class index that indicates ignored samples average: {'macro', 'micro', 'none'} method for averaging across classes num_classes : int, optional number of classes (not necessary if main_class is not None) ignored_classes : set, optional a set of class ids to ignore in calculation exclusive: bool, default True set to False for multi-label classification tasks iou_threshold : float, default 0.5 if segmental is true, intervals with IoU larger than this threshold are considered correct tag_average: {'micro', 'macro', 'none'} method for averaging across meta tags (if given) threshold_value : float | list, optional the decision threshold value (cannot be defined for exclusive classification, and 0.5 be default for non-exclusive); if threshold_value is a list, every value should correspond to the class under the same index

segmental = True

If True, the metric will be calculated over segments; otherwise over frames.

class Recall(_ClassificationMetric):
657class Recall(_ClassificationMetric):
658    """
659    Recall
660    """
661
662    def __init__(
663        self,
664        num_classes: int,
665        ignore_index: int = -100,
666        average: str = "macro",
667        ignored_classes: Set = None,
668        exclusive: bool = True,
669        tag_average: str = "micro",
670        threshold_value: Union[float, List] = None,
671    ):
672        """
673        Parameters
674        ----------
675        ignore_index : int, default -100
676            the class index that indicates ignored samples
677        average: {'macro', 'micro', 'none'}
678            method for averaging across classes
679        num_classes : int, optional
680            number of classes (not necessary if main_class is not None)
681        ignored_classes : set, optional
682            a set of class ids to ignore in calculation
683        exclusive: bool, default True
684            set to False for multi-label classification tasks
685        tag_average: {'micro', 'macro', 'none'}
686            method for averaging across meta tags (if given)
687        threshold_value : float | list, optional
688            the decision threshold value (cannot be defined for exclusive classification, and 0.5 be default
689            for non-exclusive); if `threshold_value` is a list, every value should correspond to the class
690            under the same index
691        """
692
693        if threshold_value is None:
694            threshold_value = 0.5
695        super().__init__(
696            num_classes,
697            ignore_index,
698            average,
699            ignored_classes,
700            exclusive,
701            tag_average=tag_average,
702            threshold_values=[threshold_value],
703        )
704
705    def _calculate_metric(self, tp: Dict, fp: Dict, fn: Dict, tn: Dict) -> float:
706        """
707        Calculate the metric value from true and false positive and negative rates
708        """
709
710        k = self.threshold_values[0]
711        if isinstance(k, list):
712            k = ", ".join(map(str, k))
713        return tp[k] / (tp[k] + fn[k] + 1e-7)

Recall

Recall( num_classes: int, ignore_index: int = -100, average: str = 'macro', ignored_classes: Set = None, exclusive: bool = True, tag_average: str = 'micro', threshold_value: Union[float, List] = None)
662    def __init__(
663        self,
664        num_classes: int,
665        ignore_index: int = -100,
666        average: str = "macro",
667        ignored_classes: Set = None,
668        exclusive: bool = True,
669        tag_average: str = "micro",
670        threshold_value: Union[float, List] = None,
671    ):
672        """
673        Parameters
674        ----------
675        ignore_index : int, default -100
676            the class index that indicates ignored samples
677        average: {'macro', 'micro', 'none'}
678            method for averaging across classes
679        num_classes : int, optional
680            number of classes (not necessary if main_class is not None)
681        ignored_classes : set, optional
682            a set of class ids to ignore in calculation
683        exclusive: bool, default True
684            set to False for multi-label classification tasks
685        tag_average: {'micro', 'macro', 'none'}
686            method for averaging across meta tags (if given)
687        threshold_value : float | list, optional
688            the decision threshold value (cannot be defined for exclusive classification, and 0.5 be default
689            for non-exclusive); if `threshold_value` is a list, every value should correspond to the class
690            under the same index
691        """
692
693        if threshold_value is None:
694            threshold_value = 0.5
695        super().__init__(
696            num_classes,
697            ignore_index,
698            average,
699            ignored_classes,
700            exclusive,
701            tag_average=tag_average,
702            threshold_values=[threshold_value],
703        )

Parameters

ignore_index : int, default -100 the class index that indicates ignored samples average: {'macro', 'micro', 'none'} method for averaging across classes num_classes : int, optional number of classes (not necessary if main_class is not None) ignored_classes : set, optional a set of class ids to ignore in calculation exclusive: bool, default True set to False for multi-label classification tasks tag_average: {'micro', 'macro', 'none'} method for averaging across meta tags (if given) threshold_value : float | list, optional the decision threshold value (cannot be defined for exclusive classification, and 0.5 be default for non-exclusive); if threshold_value is a list, every value should correspond to the class under the same index

class SegmentalRecall(_ClassificationMetric):
716class SegmentalRecall(_ClassificationMetric):
717    """
718    Segmental recall (not advised for training)
719    """
720
721    segmental = True
722
723    def __init__(
724        self,
725        num_classes: int,
726        ignore_index: int = -100,
727        average: str = "macro",
728        ignored_classes: Set = None,
729        exclusive: bool = True,
730        iou_threshold: float = 0.5,
731        tag_average: str = "micro",
732        threshold_value: Union[float, List] = None,
733    ):
734        """
735        Parameters
736        ----------
737        ignore_index : int, default -100
738            the class index that indicates ignored samples
739        average: {'macro', 'micro', 'none'}
740            method for averaging across classes
741        num_classes : int, optional
742            number of classes (not necessary if main_class is not None)
743        ignored_classes : set, optional
744            a set of class ids to ignore in calculation
745        exclusive: bool, default True
746            set to False for multi-label classification tasks
747        iou_threshold : float, default 0.5
748            if segmental is true, intervals with IoU larger than this threshold are considered correct
749        tag_average: {'micro', 'macro', 'none'}
750            method for averaging across meta tags (if given)
751        threshold_value : float | list, optional
752            the decision threshold value (cannot be defined for exclusive classification, and 0.5 be default
753            for non-exclusive); if `threshold_value` is a list, every value should correspond to the class
754            under the same index
755        """
756
757        if threshold_value is None:
758            threshold_value = 0.5
759        super().__init__(
760            num_classes,
761            ignore_index,
762            average,
763            ignored_classes,
764            exclusive,
765            iou_threshold,
766            tag_average,
767            [threshold_value],
768        )
769
770    def _calculate_metric(self, tp: Dict, fp: Dict, fn: Dict, tn: Dict) -> float:
771        """
772        Calculate the metric value from true and false positive and negative rates
773        """
774
775        k = self.threshold_values[0]
776        if isinstance(k, list):
777            k = ", ".join(map(str, k))
778        return tp[k] / (tp[k] + fn[k] + 1e-7)

Segmental recall (not advised for training)

SegmentalRecall( num_classes: int, ignore_index: int = -100, average: str = 'macro', ignored_classes: Set = None, exclusive: bool = True, iou_threshold: float = 0.5, tag_average: str = 'micro', threshold_value: Union[float, List] = None)
723    def __init__(
724        self,
725        num_classes: int,
726        ignore_index: int = -100,
727        average: str = "macro",
728        ignored_classes: Set = None,
729        exclusive: bool = True,
730        iou_threshold: float = 0.5,
731        tag_average: str = "micro",
732        threshold_value: Union[float, List] = None,
733    ):
734        """
735        Parameters
736        ----------
737        ignore_index : int, default -100
738            the class index that indicates ignored samples
739        average: {'macro', 'micro', 'none'}
740            method for averaging across classes
741        num_classes : int, optional
742            number of classes (not necessary if main_class is not None)
743        ignored_classes : set, optional
744            a set of class ids to ignore in calculation
745        exclusive: bool, default True
746            set to False for multi-label classification tasks
747        iou_threshold : float, default 0.5
748            if segmental is true, intervals with IoU larger than this threshold are considered correct
749        tag_average: {'micro', 'macro', 'none'}
750            method for averaging across meta tags (if given)
751        threshold_value : float | list, optional
752            the decision threshold value (cannot be defined for exclusive classification, and 0.5 be default
753            for non-exclusive); if `threshold_value` is a list, every value should correspond to the class
754            under the same index
755        """
756
757        if threshold_value is None:
758            threshold_value = 0.5
759        super().__init__(
760            num_classes,
761            ignore_index,
762            average,
763            ignored_classes,
764            exclusive,
765            iou_threshold,
766            tag_average,
767            [threshold_value],
768        )

Parameters

ignore_index : int, default -100 the class index that indicates ignored samples average: {'macro', 'micro', 'none'} method for averaging across classes num_classes : int, optional number of classes (not necessary if main_class is not None) ignored_classes : set, optional a set of class ids to ignore in calculation exclusive: bool, default True set to False for multi-label classification tasks iou_threshold : float, default 0.5 if segmental is true, intervals with IoU larger than this threshold are considered correct tag_average: {'micro', 'macro', 'none'} method for averaging across meta tags (if given) threshold_value : float | list, optional the decision threshold value (cannot be defined for exclusive classification, and 0.5 be default for non-exclusive); if threshold_value is a list, every value should correspond to the class under the same index

segmental = True

If True, the metric will be calculated over segments; otherwise over frames.

class F1(_ClassificationMetric):
781class F1(_ClassificationMetric):
782    """
783    F1 score
784    """
785
786    def __init__(
787        self,
788        num_classes: int,
789        ignore_index: int = -100,
790        average: str = "macro",
791        ignored_classes: Set = None,
792        exclusive: bool = True,
793        tag_average: str = "micro",
794        threshold_value: Union[float, List] = None,
795        integration_interval: int = 0,
796    ):
797        """
798        Parameters
799        ----------
800        ignore_index : int, default -100
801            the class index that indicates ignored samples
802        average: {'macro', 'micro', 'none'}
803            method for averaging across classes
804        num_classes : int, optional
805            number of classes (not necessary if main_class is not None)
806        ignored_classes : set, optional
807            a set of class ids to ignore in calculation
808        exclusive: bool, default True
809            set to False for multi-label classification tasks
810        tag_average: {'micro', 'macro', 'none'}
811            method for averaging across meta tags (if given)
812        threshold_value : float | list, optional
813            the decision threshold value (cannot be defined for exclusive classification, and 0.5 be default
814            for non-exclusive); if `threshold_value` is a list, every value should correspond to the class
815            under the same index
816        """
817
818        if threshold_value is None:
819            threshold_value = 0.5
820        super().__init__(
821            num_classes,
822            ignore_index,
823            average,
824            ignored_classes,
825            exclusive,
826            tag_average=tag_average,
827            threshold_values=[threshold_value],
828            integration_interval=integration_interval,
829        )
830
831    def _calculate_metric(self, tp: Dict, fp: Dict, fn: Dict, tn: Dict) -> float:
832        """
833        Calculate the metric value from true and false positive and negative rates
834        """
835
836        if self.optimize:
837            scores = []
838            for k in self.threshold_values:
839                recall = tp[k] / (tp[k] + fn[k] + 1e-7)
840                precision = tp[k] / (tp[k] + fp[k] + 1e-7)
841                scores.append(2 * recall * precision / (recall + precision + 1e-7))
842            f1 = max(scores)
843        else:
844            k = self.threshold_values[0]
845            if isinstance(k, list):
846                k = ", ".join(map(str, k))
847            recall = tp[k] / (tp[k] + fn[k] + 1e-7)
848            precision = tp[k] / (tp[k] + fp[k] + 1e-7)
849            f1 = 2 * recall * precision / (recall + precision + 1e-7)
850        return f1
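
For a single decision threshold, the computation above reduces to the textbook F1 combination of precision and recall. The following worked example is illustrative only, with made-up counts (the `0.5` key mirrors the default threshold value):

    # Hypothetical counts at the default threshold 0.5 (illustration, not library output)
    tp = {0.5: 80.0}
    fp = {0.5: 20.0}
    fn = {0.5: 10.0}

    k = 0.5
    recall = tp[k] / (tp[k] + fn[k] + 1e-7)       # 80 / 90 ≈ 0.889
    precision = tp[k] / (tp[k] + fp[k] + 1e-7)    # 80 / 100 = 0.800
    f1 = 2 * recall * precision / (recall + precision + 1e-7)
    print(round(f1, 3))                           # ≈ 0.842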


class SegmentalF1(_ClassificationMetric):
853class SegmentalF1(_ClassificationMetric):
854    """
855    Segmental F1 score (not advised for training)
856    """
857
858    segmental = True
859
860    def __init__(
861        self,
862        num_classes: int,
863        ignore_index: int = -100,
864        average: str = "macro",
865        ignored_classes: Set = None,
866        exclusive: bool = True,
867        iou_threshold: float = 0.5,
868        tag_average: str = "micro",
869        threshold_value: Union[float, List] = None,
870    ):
871        """
872        Parameters
873        ----------
874        ignore_index : int, default -100
875            the class index that indicates ignored samples
876        average: {'macro', 'micro', 'none'}
877            method for averaging across classes
878        num_classes : int, optional
879            number of classes (not necessary if main_class is not None)
880        ignored_classes : set, optional
881            a set of class ids to ignore in calculation
882        exclusive: bool, default True
883            set to False for multi-label classification tasks
884        iou_threshold : float, default 0.5
885            if segmental is true, intervals with IoU larger than this threshold are considered correct
886        tag_average: {'micro', 'macro', 'none'}
887            method for averaging across meta tags (if given)
888        threshold_value : float | list, optional
889            the decision threshold value (cannot be defined for exclusive classification, and 0.5 by default
890            for non-exclusive); if `threshold_value` is a list, every value should correspond to the class
891            under the same index
892        """
893
894        if threshold_value is None:
895            threshold_value = 0.5
896        super().__init__(
897            num_classes,
898            ignore_index,
899            average,
900            ignored_classes,
901            exclusive,
902            iou_threshold,
903            tag_average,
904            [threshold_value],
905        )
906
907    def _calculate_metric(self, tp: Dict, fp: Dict, fn: Dict, tn: Dict) -> float:
908        """
909        Calculate the metric value from true and false positive and negative rates
910        """
911
912        if self.optimize:
913            scores = []
914            for k in self.threshold_values:
915                recall = tp[k] / (tp[k] + fn[k] + 1e-7)
916                precision = tp[k] / (tp[k] + fp[k] + 1e-7)
917                scores.append(2 * recall * precision / (recall + precision + 1e-7))
918            f1 = max(scores)
919        else:
920            k = self.threshold_values[0]
921            if isinstance(k, list):
922                k = ", ".join(map(str, k))
923            recall = tp[k] / (tp[k] + fn[k] + 1e-7)
924            precision = tp[k] / (tp[k] + fp[k] + 1e-7)
925            f1 = 2 * recall * precision / (recall + precision + 1e-7)
926        return f1
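
The segmental variant scores whole intervals rather than frames: a predicted segment counts towards `tp` only when its IoU with a ground truth segment exceeds `iou_threshold`. The matching itself happens in the base class; the helper below is a minimal sketch of the IoU criterion, assuming half-open `(start, end)` frame intervals, and is not the library's implementation:

    def interval_iou(a, b):
        # a and b are (start, end) frame intervals, end exclusive
        intersection = max(0, min(a[1], b[1]) - max(a[0], b[0]))
        union = (a[1] - a[0]) + (b[1] - b[0]) - intersection
        return intersection / union if union > 0 else 0.0

    # IoU = 75 / 125 = 0.6, so this pair passes the default 0.5 threshold
    print(interval_iou((0, 100), (25, 125)) > 0.5)  # True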


segmental = True

If True, the metric will be calculated over segments; otherwise over frames.

class Fbeta(_ClassificationMetric):
 929class Fbeta(_ClassificationMetric):
 930    """
 931    F-beta score
 932    """
 933
 934    def __init__(
 935        self,
 936        beta: float = 1,
 937        ignore_index: int = -100,
 938        average: str = "macro",
 939        num_classes: int = None,
 940        ignored_classes: Set = None,
 941        tag_average: str = "micro",
 942        exclusive: bool = True,
 943        threshold_value: float = 0.5,
 944    ):
 945        """
 946        Parameters
 947        ----------
 948        beta : float, default 1
 949            the beta parameter
 950        ignore_index : int, default -100
 951            the class index that indicates ignored samples
 952        average: {'macro', 'micro', 'none'}
 953            method for averaging across classes
 954        num_classes : int, optional
 955            number of classes (not necessary if main_class is not None)
 956        ignored_classes : set, optional
 957            a set of class ids to ignore in calculation
 958        exclusive: bool, default True
 959            set to False for multi-label classification tasks
 960        tag_average: {'micro', 'macro', 'none'}
 961            method for averaging across meta tags (if given)
 962        threshold_value : float | list, optional
 963            the decision threshold value (cannot be defined for exclusive classification, and 0.5 by default
 964            for non-exclusive); if `threshold_value` is a list, every value should correspond to the class
 965            under the same index
 966        """
 967
 968        if threshold_value is None:
 969            threshold_value = 0.5
 970        self.beta2 = beta**2
 971        super().__init__(
 972            num_classes,
 973            ignore_index,
 974            average,
 975            ignored_classes,
 976            exclusive,
 977            tag_average=tag_average,
 978            threshold_values=[threshold_value],
 979        )
 980
 981    def _calculate_metric(self, tp: Dict, fp: Dict, fn: Dict, tn: Dict) -> float:
 982        """
 983        Calculate the metric value from true and false positive and negative rates
 984        """
 985
 986        if self.optimize:
 987            scores = []
 988            for k in self.threshold_values:
 989                recall = tp[k] / (tp[k] + fn[k] + 1e-7)
 990                precision = tp[k] / (tp[k] + fp[k] + 1e-7)
 991                scores.append(
 992                    (
 993                        (1 + self.beta2)
 994                        * precision
 995                        * recall
 996                        / (self.beta2 * precision + recall + 1e-7)
 997                    )
 998                )
 999            f1 = max(scores)
1000        else:
1001            k = self.threshold_values[0]
1002            if isinstance(k, list):
1003                k = ", ".join(map(str, k))
1004            recall = tp[k] / (tp[k] + fn[k] + 1e-7)
1005            precision = tp[k] / (tp[k] + fp[k] + 1e-7)
1006            f1 = (
1007                (1 + self.beta2)
1008                * precision
1009                * recall
1010                / (self.beta2 * precision + recall + 1e-7)
1011            )
1012        return f1
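
Since `beta2 = beta ** 2`, values of `beta` above 1 pull the score towards recall and values below 1 towards precision (with `beta = 1` this is plain F1). A small hypothetical check of that weighting:

    def fbeta(precision, recall, beta):
        # the same combination as above, for scalar precision/recall
        beta2 = beta ** 2
        return (1 + beta2) * precision * recall / (beta2 * precision + recall + 1e-7)

    print(round(fbeta(0.8, 0.4, beta=1.0), 3))  # 0.533, the plain F1
    print(round(fbeta(0.8, 0.4, beta=2.0), 3))  # 0.444, pulled towards the lower recall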


class SegmentalFbeta(_ClassificationMetric):
1015class SegmentalFbeta(_ClassificationMetric):
1016    """
1017    Segmental F-beta score (not advised for training)
1018    """
1019
1020    segmental = True
1021
1022    def __init__(
1023        self,
1024        beta: float = 1,
1025        ignore_index: int = -100,
1026        average: str = "macro",
1027        num_classes: int = None,
1028        ignored_classes: Set = None,
1029        iou_threshold: float = 0.5,
1030        tag_average: str = "micro",
1031        exclusive: bool = True,
1032        threshold_value: float = 0.5,
1033    ):
1034        """
1035        Parameters
1036        ----------
1037        beta : float, default 1
1038            the beta parameter
1039        ignore_index : int, default -100
1040            the class index that indicates ignored samples
1041        average: {'macro', 'micro', 'none'}
1042            method for averaging across classes
1043        num_classes : int, optional
1044            number of classes (not necessary if main_class is not None)
1045        ignored_classes : set, optional
1046            a set of class ids to ignore in calculation
1047        exclusive: bool, default True
1048            set to False for multi-label classification tasks
1049        iou_threshold : float, default 0.5
1050            if segmental is true, intervals with IoU larger than this threshold are considered correct
1051        tag_average: {'micro', 'macro', 'none'}
1052            method for averaging across meta tags (if given)
1053        threshold_value : float | list, optional
1054            the decision threshold value (cannot be defined for exclusive classification, and 0.5 by default
1055            for non-exclusive); if `threshold_value` is a list, every value should correspond to the class
1056            under the same index
1057        """
1058
1059        if threshold_value is None:
1060            threshold_value = 0.5
1061        self.beta2 = beta**2
1062        super().__init__(
1063            num_classes,
1064            ignore_index,
1065            average,
1066            ignored_classes,
1067            exclusive,
1068            iou_threshold=iou_threshold,
1069            tag_average=tag_average,
1070            threshold_values=[threshold_value],
1071        )
1072
1073    def _calculate_metric(self, tp: Dict, fp: Dict, fn: Dict, tn: Dict) -> float:
1074        """
1075        Calculate the metric value from true and false positive and negative rates
1076        """
1077
1078        if self.optimize:
1079            scores = []
1080            for k in self.threshold_values:
1081                recall = tp[k] / (tp[k] + fn[k] + 1e-7)
1082                precision = tp[k] / (tp[k] + fp[k] + 1e-7)
1083                scores.append(
1084                    (
1085                        (1 + self.beta2)
1086                        * precision
1087                        * recall
1088                        / (self.beta2 * precision + recall + 1e-7)
1089                    )
1090                )
1091            f1 = max(scores)
1092        else:
1093            k = self.threshold_values[0]
1094            if isinstance(k, list):
1095                k = ", ".join(map(str, k))
1096            recall = tp[k] / (tp[k] + fn[k] + 1e-7)
1097            precision = tp[k] / (tp[k] + fp[k] + 1e-7)
1098            f1 = (
1099                (1 + self.beta2)
1100                * precision
1101                * recall
1102                / (self.beta2 * precision + recall + 1e-7)
1103            )
1104        return f1


segmental = True

If True, the metric will be calculated over segments; otherwise over frames.

class SemiSegmentalRecall(_SemiSegmentalMetric):
1357class SemiSegmentalRecall(_SemiSegmentalMetric):
1358    """
1359    Semi-segmental recall (not advised for training)
1360
1361    A metric in-between segmental and frame-wise recall.
1362
1363    This metric is computed with the following algorithm:
1364    1) smooth over too-short intervals, both in ground truth and in prediction (first remove
1365        groups of zeros shorter than `smooth_interval` and then do the same with groups of ones),
1366    2) add `delta` frames to each ground truth interval at both ends and count the number of predicted
1367        positive frames at the resulting intervals (intersection),
1368    3) calculate the threshold for each interval as
1369        `t = sigmoid(4 * (a * x + b)) * (iou_threshold_long - iou_threshold_short)`, where
1370        `a = 2 / (long_length - short_length)`, `b = 1 - a * long_length`, `x` is the length of the interval
1371        before `delta` was added,
1372    4) for each interval, if intersection is higher than `t * x`, the interval is labeled as true positive (`TP`),
1373        and otherwise as false negative (`FN`),
1374    5) the final metric value is computed as `TP / (TP + FN)`.
1375    """
1376
1377    def __init__(
1378        self,
1379        num_classes: int,
1380        ignore_index: int = -100,
1381        ignored_classes: Set = None,
1382        exclusive: bool = True,
1383        average: str = "macro",
1384        tag_average: str = "micro",
1385        delta: int = 0,
1386        smooth_interval: int = 0,
1387        iou_threshold_long: float = 0.5,
1388        iou_threshold_short: float = 0.5,
1389        short_length: int = 30,
1390        long_length: int = 300,
1391        threshold_value: Union[float, List] = None,
1392    ) -> None:
1393        """
1394        Parameters
1395        ----------
1396        num_classes : int
1397            the number of classes in the dataset
1398        ignore_index : int, default -100
1399            the ground truth label to ignore
1400        ignored_classes : set, optional
1401            the class indices to ignore in computation
1402        exclusive : bool, default True
1403            `False` for multi-label classification tasks
1404        average : {"macro", "micro", "none"}
1405            the method to average the results over classes
1406        tag_average : {"macro", "micro", "none"}
1407            the method to average the results over meta tags (if given)
1408        delta : int, default 0
1409            the number of frames to add to each ground truth interval before computing the intersection,
1410            see description of the class for details
1411        smooth_interval : int, default 0
1412            intervals shorter than this number of frames will be ignored (both in prediction and in ground truth),
1413            see description of the class for details
1414        iou_threshold_long : float, default 0.5
1415            the intersection threshold for segments longer than `long_length` frames (between 0 and 1),
1416            see description of the class for details
1417        iou_threshold_short : float, default 0.5
1418            the intersection threshold for segments shorter than `short_length` frames (between 0 and 1),
1419            see description of the class for details
1420        short_length : int, default 30
1421            the threshold number of frames for short intervals that will have an intersection threshold of
1422            `iou_threshold_short`, see description of the class for details
1423        long_length : int, default 300
1424            the threshold number of frames for long intervals that will have an intersection threshold of
1425            `iou_threshold_long`, see description of the class for details
1426        """
1427
1428        super().__init__(
1429            num_classes,
1430            ignore_index,
1431            ignored_classes,
1432            exclusive,
1433            average,
1434            tag_average,
1435            delta,
1436            smooth_interval,
1437            iou_threshold_long,
1438            iou_threshold_short,
1439            short_length,
1440            long_length,
1441            [threshold_value],
1442        )
1443
1444    def _calculate_metric(self, tp: Dict, fp: Dict, fn: Dict, tn: Dict) -> float:
1445        """
1446        Calculate the metric value from true and false positive and negative rates
1447        """
1448
1449        k = self.threshold_values[0]
1450        if isinstance(k, list):
1451            k = ", ".join(map(str, k))
1452        return tp[k] / (tp[k] + fn[k] + 1e-7)
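
Step 3 of the algorithm interpolates the acceptance threshold between the short- and long-interval settings as a function of interval length. The sketch below is a direct transcription of the docstring formula with illustrative threshold settings (note that with the class defaults of 0.5 and 0.5 the difference term is zero); it is an illustration, not the library's code:

    import numpy as np

    def interval_threshold(x, iou_threshold_short=0.3, iou_threshold_long=0.7,
                           short_length=30, long_length=300):
        # x is the interval length in frames, before delta is added
        a = 2 / (long_length - short_length)
        b = 1 - a * long_length
        sigmoid = 1 / (1 + np.exp(-4 * (a * x + b)))
        return sigmoid * (iou_threshold_long - iou_threshold_short)

    # short intervals land near sigmoid(-4), long intervals near sigmoid(4)
    print(round(interval_threshold(30), 3))   # ≈ 0.007
    print(round(interval_threshold(300), 3))  # ≈ 0.393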


class SemiSegmentalPrecision(_SemiSegmentalMetric):
1455class SemiSegmentalPrecision(_SemiSegmentalMetric):
1456    """
1457    Semi-segmental precision (not advised for training)
1458
1459    A metric in-between segmental and frame-wise precision.
1460
1461    This metric is computed with the following algorithm:
1462    1) smooth over too-short intervals, both in ground truth and in prediction (first remove
1463        groups of zeros shorter than `smooth_interval` and then do the same with groups of ones),
1464    2) add `delta` frames to each predicted interval at both ends and count the number of ground truth
1465        positive frames at the resulting intervals (intersection),
1466    3) calculate the threshold for each interval as
1467        `t = sigmoid(4 * (a * x + b)) * (iou_threshold_long - iou_threshold_short)`, where
1468        `a = 2 / (long_length - short_length)`, `b = 1 - a * long_length`, `x` is the length of the interval
1469        before `delta` was added,
1470    4) for each interval, if intersection is higher than `t * x`, the interval is labeled as true positive (`TP`),
1471        and otherwise as false positive (`FP`),
1472    5) the final metric value is computed as `TP / (TP + FP)`.
1473    """
1474
1475    def __init__(
1476        self,
1477        num_classes: int,
1478        ignore_index: int = -100,
1479        ignored_classes: Set = None,
1480        exclusive: bool = True,
1481        average: str = "macro",
1482        tag_average: str = "micro",
1483        delta: int = 0,
1484        smooth_interval: int = 0,
1485        iou_threshold_long: float = 0.5,
1486        iou_threshold_short: float = 0.5,
1487        short_length: int = 30,
1488        long_length: int = 300,
1489        threshold_value: Union[float, List] = None,
1490    ) -> None:
1491        """
1492        Parameters
1493        ----------
1494        num_classes : int
1495            the number of classes in the dataset
1496        ignore_index : int, default -100
1497            the ground truth label to ignore
1498        ignored_classes : set, optional
1499            the class indices to ignore in computation
1500        exclusive : bool, default True
1501            `False` for multi-label classification tasks
1502        average : {"macro", "micro", "none"}
1503            the method to average the results over classes
1504        tag_average : {"macro", "micro", "none"}
1505            the method to average the results over meta tags (if given)
1506        delta : int, default 0
1507            the number of frames to add to each ground truth interval before computing the intersection,
1508            see description of the class for details
1509        smooth_interval : int, default 0
1510            intervals shorter than this number of frames will be ignored (both in prediction and in ground truth),
1511            see description of the class for details
1512        iou_threshold_long : float, default 0.5
1513            the intersection threshold for segments longer than `long_length` frames (between 0 and 1),
1514            see description of the class for details
1515        iou_threshold_short : float, default 0.5
1516            the intersection threshold for segments shorter than `short_length` frames (between 0 and 1),
1517            see description of the class for details
1518        short_length : int, default 30
1519            the threshold number of frames for short intervals that will have an intersection threshold of
1520            `iou_threshold_short`, see description of the class for details
1521        long_length : int, default 300
1522            the threshold number of frames for long intervals that will have an intersection threshold of
1523            `iou_threshold_long`, see description of the class for details
1524        """
1525
1526        super().__init__(
1527            num_classes,
1528            ignore_index,
1529            ignored_classes,
1530            exclusive,
1531            average,
1532            tag_average,
1533            delta,
1534            smooth_interval,
1535            iou_threshold_long,
1536            iou_threshold_short,
1537            short_length,
1538            long_length,
1539            [threshold_value],
1540        )
1541
1542    def _calculate_metric(self, tp: Dict, fp: Dict, fn: Dict, tn: Dict) -> float:
1543        """
1544        Calculate the metric value from true and false positive and negative rates
1545        """
1546
1547        k = self.threshold_values[0]
1548        if isinstance(k, list):
1549            k = ", ".join(map(str, k))
1550        return tn[k] / (tn[k] + fp[k] + 1e-7)


class SemiSegmentalF1(_SemiSegmentalMetric):
1553class SemiSegmentalF1(_SemiSegmentalMetric):
1554    """
1555    The F1 score for semi-segmental recall and precision (not advised for training)
1556    """
1557
1558    def __init__(
1559        self,
1560        num_classes: int,
1561        ignore_index: int = -100,
1562        ignored_classes: Set = None,
1563        exclusive: bool = True,
1564        average: str = "macro",
1565        tag_average: str = "micro",
1566        delta: int = 0,
1567        smooth_interval: int = 0,
1568        iou_threshold_long: float = 0.5,
1569        iou_threshold_short: float = 0.5,
1570        short_length: int = 30,
1571        long_length: int = 300,
1572        threshold_value: Union[float, List] = None,
1573    ) -> None:
1574        """
1575        Parameters
1576        ----------
1577        num_classes : int
1578            the number of classes in the dataset
1579        ignore_index : int, default -100
1580            the ground truth label to ignore
1581        ignored_classes : set, optional
1582            the class indices to ignore in computation
1583        exclusive : bool, default True
1584            `False` for multi-label classification tasks
1585        average : {"macro", "micro", "none"}
1586            the method to average the results over classes
1587        tag_average : {"macro", "micro", "none"}
1588            the method to average the results over meta tags (if given)
1589        delta : int, default 0
1590            the number of frames to add to each ground truth interval before computing the intersection,
1591            see description of the class for details
1592        smooth_interval : int, default 0
1593            intervals shorter than this number of frames will be ignored (both in prediction and in ground truth,
1594            see description of the class for details
1595        iou_threshold_long : float, default 0.5
1596            the intersection threshold for segments longer than `long_length` frames (between 0 and 1),
1597            see description of the class for details
1598        iou_threshold_short : float, default 0.5
1599            the intersection threshold for segments shorter than `short_length` frames (between 0 and 1),
1600            see description of the class for details
1601        short_length : int, default 30
1602            the threshold number of frames for short intervals that will have an intersection threshold of
1603            `iou_threshold_short`, see description of the class for details
1604        long_length : int, default 300
1605            the threshold number of frames for long intervals that will have an intersection threshold of
1606            `iou_threshold_long`, see description of the class for details
1607        """
1608
1609        super().__init__(
1610            num_classes,
1611            ignore_index,
1612            ignored_classes,
1613            exclusive,
1614            average,
1615            tag_average,
1616            delta,
1617            smooth_interval,
1618            iou_threshold_long,
1619            iou_threshold_short,
1620            short_length,
1621            long_length,
1622            [threshold_value],
1623        )
1624
1625    def _calculate_metric(self, tp: Dict, fp: Dict, fn: Dict, tn: Dict) -> float:
1626        """
1627        Calculate the metric value from true and false positive and negative rates
1628        """
1629
1630        if self.optimize:
1631            scores = []
1632            for k in self.threshold_values:
1633                recall = tp[k] / (tp[k] + fn[k] + 1e-7)
1634                precision = tn[k] / (tn[k] + fp[k] + 1e-7)
1635                scores.append(2 * recall * precision / (recall + precision + 1e-7))
1636            f1 = max(scores)
1637        else:
1638            k = self.threshold_values[0]
1639            if isinstance(k, list):
1640                k = ", ".join(map(str, k))
1641            recall = tp[k] / (tp[k] + fn[k] + 1e-7)
1642            precision = tn[k] / (tn[k] + fp[k] + 1e-7)
1643            f1 = 2 * recall * precision / (recall + precision + 1e-7)
1644        return f1


class SemiSegmentalPR_AUC(_SemiSegmentalMetric):
1647class SemiSegmentalPR_AUC(_SemiSegmentalMetric):
1648    """
1649    The area under the precision-recall curve for semi-segmental metrics (not advised for training)
1650    """
1651
1652    def __init__(
1653        self,
1654        num_classes: int,
1655        ignore_index: int = -100,
1656        ignored_classes: Set = None,
1657        exclusive: bool = True,
1658        average: str = "macro",
1659        tag_average: str = "micro",
1660        delta: int = 0,
1661        smooth_interval: int = 0,
1662        iou_threshold_long: float = 0.5,
1663        iou_threshold_short: float = 0.5,
1664        short_length: int = 30,
1665        long_length: int = 300,
1666        threshold_step: float = 0.1,
1667    ) -> None:
1668        super().__init__(
1669            num_classes,
1670            ignore_index,
1671            ignored_classes,
1672            exclusive,
1673            average,
1674            tag_average,
1675            delta,
1676            smooth_interval,
1677            iou_threshold_long,
1678            iou_threshold_short,
1679            short_length,
1680            long_length,
1681            list(np.arange(0, 1, threshold_step)),
1682        )
1683
1684    def _calculate_metric(self, tp: Dict, fp: Dict, fn: Dict, tn: Dict) -> float:
1685        """
1686        Calculate the metric value from true and false positive and negative rates
1687        """
1688
1689        precisions = []
1690        recalls = []
1691        for k in sorted(self.threshold_values):
1692            precisions.append(tn[k] / (tn[k] + fp[k] + 1e-7))
1693            recalls.append(tp[k] / (tp[k] + fn[k] + 1e-7))
1694        return metrics.auc(x=recalls, y=precisions)
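
The threshold sweep produced by `np.arange(0, 1, threshold_step)` yields one (recall, precision) point per threshold, and `metrics.auc` integrates the resulting curve with the trapezoidal rule. A self-contained sketch with made-up points (not library output):

    from sklearn import metrics

    # hypothetical (recall, precision) pairs from a threshold sweep;
    # recall falls and precision rises as the threshold grows
    recalls = [0.95, 0.80, 0.60, 0.30]
    precisions = [0.40, 0.55, 0.75, 0.90]

    # metrics.auc accepts monotonically decreasing x as well
    print(round(metrics.auc(x=recalls, y=precisions), 3))  # ≈ 0.449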


class Accuracy(dlc2action.metric.base_metric.Metric):
1697class Accuracy(Metric):
1698    """
1699    Accuracy
1700    """
1701
1702    def __init__(self, ignore_index=-100):
1703        """
1704        Parameters
1705        ----------
1706        ignore_index: int
1707            the class index that indicates ignored samples
1708        """
1709
1710        super().__init__()
1711        self.ignore_index = ignore_index
1712
1713    def reset(self) -> None:
1714        """
1715        Reset the intrinsic parameters (at the beginning of an epoch)
1716        """
1717
1718        self.total = 0
1719        self.correct = 0
1720
1721    def calculate(self) -> float:
1722        """
1723        Calculate the metric value
1724
1725        Returns
1726        -------
1727        metric : float
1728            metric value
1729        """
1730
1731        return self.correct / (self.total + 1e-7)
1732
1733    def update(
1734        self,
1735        predicted: torch.Tensor,
1736        target: torch.Tensor,
1737        tags: torch.Tensor = None,
1738    ) -> None:
1739        """
1740        Update the intrinsic parameters (with a batch)
1741
1742        Parameters
1743        ----------
1744        predicted : torch.Tensor
1745            the main prediction tensor generated by the model
1748        target : torch.Tensor
1749            the corresponding main target tensor
1752        tags : torch.Tensor
1753            the tensor of meta tags (or `None`, if tags are not given)
1754        """
1755
1756        mask = target != self.ignore_index
1757        self.total += torch.sum(mask)
1758        self.correct += torch.sum((target == predicted)[mask])

Accuracy

Accuracy(ignore_index=-100)
1702    def __init__(self, ignore_index=-100):
1703        """
1704        Parameters
1705        ----------
1706        ignore_index: int
1707            the class index that indicates ignored samples
1708        """
1709
1710        super().__init__()
1711        self.ignore_index = ignore_index

Parameters

ignore_index : int, default -100
    the class index that indicates ignored samples

def reset(self) -> None:
1713    def reset(self) -> None:
1714        """
1715        Reset the intrinsic parameters (at the beginning of an epoch)
1716        """
1717
1718        self.total = 0
1719        self.correct = 0

Reset the intrinsic parameters (at the beginning of an epoch)

def calculate(self) -> float:
1721    def calculate(self) -> float:
1722        """
1723        Calculate the metric value
1724
1725        Returns
1726        -------
1727        metric : float
1728            metric value
1729        """
1730
1731        return self.correct / (self.total + 1e-7)

Calculate the metric value

Returns

metric : float
    metric value

def update( self, predicted: torch.Tensor, target: torch.Tensor, tags: torch.Tensor = None) -> None:
1733    def update(
1734        self,
1735        predicted: torch.Tensor,
1736        target: torch.Tensor,
1737        tags: torch.Tensor = None,
1738    ) -> None:
1739        """
1740        Update the intrinsic parameters (with a batch)
1741
1742        Parameters
1743        ----------
1744        predicted : torch.Tensor
1745            the main prediction tensor generated by the model
1748        target : torch.Tensor
1749            the corresponding main target tensor
1752        tags : torch.Tensor
1753            the tensor of meta tags (or `None`, if tags are not given)
1754        """
1755
1756        mask = target != self.ignore_index
1757        self.total += torch.sum(mask)
1758        self.correct += torch.sum((target == predicted)[mask])

Update the intrinsic parameters (with a batch)

Parameters

predicted : torch.Tensor
    the main prediction tensor generated by the model
target : torch.Tensor
    the corresponding main target tensor
tags : torch.Tensor
    the tensor of meta tags (or `None`, if tags are not given)
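
A minimal usage sketch with hypothetical class-index tensors (-100 marks frames that should be ignored):

import torch

acc = Accuracy()
acc.reset()

predicted = torch.tensor([0, 1, 1, 2])
target = torch.tensor([0, 1, 2, -100])  # the last frame is ignored

acc.update(predicted, target)
print(acc.calculate())  # 2 correct out of 3 counted frames, ~0.667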

class Count(dlc2action.metric.base_metric.Metric):
1761class Count(Metric):
1762    """
1763    Fraction of samples labeled by the model as a class
1764    """
1765
1766    def __init__(self, classes: Set, exclusive: bool = True):
1767        """
1768        Parameters
1769        ----------
1770        classes : set
1771            the set of classes to count
1772        exclusive: bool, default True
1773            set to False for multi-label classification tasks
1774        """
1775
1776        super().__init__()
1777        self.classes = classes
1778        self.exclusive = exclusive
1779
1780    def reset(self) -> None:
1781        """
1782        Reset the intrinsic parameters (at the beginning of an epoch)
1783        """
1784
1785        self.count = defaultdict(lambda: 0)
1786        self.total = 0
1787
1788    def update(
1789        self,
1790        predicted: torch.Tensor,
1791        target: torch.Tensor,
1792        tags: torch.Tensor,
1793    ) -> None:
1794        """
1795        Update the intrinsic parameters (with a batch)
1796
1797        Parameters
1798        ----------
1799        predicted : torch.Tensor
1800            the main prediction tensor generated by the model
1803        target : torch.Tensor
1804            the corresponding main target tensor
1807        tags : torch.Tensor
1808            the tensor of meta tags (or `None`, if tags are not given)
1809        """
1810
1811        if self.exclusive:
1812            for c in self.classes:
1813                self.count[c] += torch.sum(predicted == c)
1814            self.total += torch.numel(predicted)
1815        else:
1816            for c in self.classes:
1817                self.count[c] += torch.sum(predicted[:, c, :] == 1)
1818            self.total += torch.numel(predicted[:, 0, :])
1819
1820    def calculate(self) -> Dict:
1821        """
1822        Calculate the metric (at the end of an epoch)
1823
1824        Returns
1825        -------
1826        result : dict
1827            a dictionary where the keys are class indices and the values are class metric values
1828        """
1829
1830        # divide on the fly so that repeated calls do not re-normalize the counts
1831        return {c: self.count[c] / (self.total + 1e-7) for c in self.classes}

Fraction of samples labeled by the model as a class

Count(classes: Set, exclusive: bool = True)
1766    def __init__(self, classes: Set, exclusive: bool = True):
1767        """
1768        Parameters
1769        ----------
1770        classes : set
1771            the set of classes to count
1772        exclusive: bool, default True
1773            set to False for multi-label classification tasks
1774        """
1775
1776        super().__init__()
1777        self.classes = classes
1778        self.exclusive = exclusive

Parameters

classes : set
    the set of classes to count
exclusive : bool, default True
    set to False for multi-label classification tasks

def reset(self) -> None:
1780    def reset(self) -> None:
1781        """
1782        Reset the intrinsic parameters (at the beginning of an epoch)
1783        """
1784
1785        self.count = defaultdict(lambda: 0)
1786        self.total = 0

Reset the intrinsic parameters (at the beginning of an epoch)

def update( self, predicted: torch.Tensor, target: torch.Tensor, tags: torch.Tensor) -> None:
1788    def update(
1789        self,
1790        predicted: torch.Tensor,
1791        target: torch.Tensor,
1792        tags: torch.Tensor,
1793    ) -> None:
1794        """
1795        Update the intrinsic parameters (with a batch)
1796
1797        Parameters
1798        ----------
1799        predicted : torch.Tensor
1800            the main prediction tensor generated by the model
1803        target : torch.Tensor
1804            the corresponding main target tensor
1807        tags : torch.Tensor
1808            the tensor of meta tags (or `None`, if tags are not given)
1809        """
1810
1811        if self.exclusive:
1812            for c in self.classes:
1813                self.count[c] += torch.sum(predicted == c)
1814            self.total += torch.numel(predicted)
1815        else:
1816            for c in self.classes:
1817                self.count[c] += torch.sum(predicted[:, c, :] == 1)
1818            self.total += torch.numel(predicted[:, 0, :])

Update the intrinsic parameters (with a batch)

Parameters

predicted : torch.Tensor
    the main prediction tensor generated by the model
target : torch.Tensor
    the corresponding main target tensor
tags : torch.Tensor
    the tensor of meta tags (or `None`, if tags are not given)

def calculate(self) -> Dict:
1820    def calculate(self) -> Dict:
1821        """
1822        Calculate the metric (at the end of an epoch)
1823
1824        Returns
1825        -------
1826        result : dict
1827            a dictionary where the keys are class indices and the values are class metric values
1828        """
1829
1830        # divide on the fly so that repeated calls do not re-normalize the counts
1831        return {c: self.count[c] / (self.total + 1e-7) for c in self.classes}

Calculate the metric (at the end of an epoch)

Returns

result : dict
    a dictionary where the keys are class indices and the values are class metric values
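
A minimal usage sketch with hypothetical exclusive predictions (class indices per frame); the target tensor is required by the interface but not used by this metric:

import torch

count = Count(classes={0, 1})
count.reset()

predicted = torch.tensor([[0, 0, 1, 1, 1]])
target = torch.zeros_like(predicted)

count.update(predicted, target, tags=None)
print(count.calculate())  # {0: ~0.4, 1: ~0.6}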

class EditDistance(dlc2action.metric.base_metric.Metric):
1835class EditDistance(Metric):
1836    """
1837    Edit distance (not advised for training)
1838
1839    Normalized by the length of the sequences
1840    """
1841
1842    def __init__(self, ignore_index: int = -100, exclusive: bool = True) -> None:
1843        """
1844        Parameters
1845        ----------
1846        ignore_index : int, default -100
1847            the class index that indicates samples that should be ignored
1848        exclusive : bool, default True
1849            set to False for multi-label classification tasks
1850        """
1851
1852        super().__init__()
1853        self.ignore_index = ignore_index
1854        self.exclusive = exclusive
1855
1856    def reset(self) -> None:
1857        """
1858        Reset the intrinsic parameters (at the beginning of an epoch)
1859        """
1860
1861        self.edit_distance = 0
1862        self.total = 0
1863
1864    def update(
1865        self,
1866        predicted: torch.Tensor,
1867        target: torch.Tensor,
1868        tags: torch.Tensor,
1869    ) -> None:
1870        """
1871        Update the intrinsic parameters (with a batch)
1872
1873        Parameters
1874        ----------
1875        predicted : torch.Tensor
1876            the main prediction tensor generated by the model
1879        target : torch.Tensor
1880            the corresponding main target tensor
1883        tags : torch.Tensor
1884            the tensor of meta tags (or `None`, if tags are not given)
1885        """
1886
1887        mask = target != self.ignore_index
1888        self.total += torch.sum(mask)
1889        if self.exclusive:
1890            predicted = predicted[mask].flatten()
1891            target = target[mask].flatten()
1892            self.edit_distance += editdistance.eval(
1893                predicted.detach().cpu().numpy(), target.detach().cpu().numpy()
1894            )
1895        else:
1896            for c in range(target.shape[1]):
1897                predicted_class = predicted[:, c, :][mask[:, c, :]].flatten()
1898                target_class = target[:, c, :][mask[:, c, :]].flatten()
1899                self.edit_distance += editdistance.eval(
1900                    predicted_class.detach().cpu().tolist(),
1901                    target_class.detach().cpu().tolist(),
1902                )
1903
1904    def _is_equal(self, a, b):
1905        """
1906        Compare while ignoring samples marked with ignore_index
1907        """
1908
1909        return self.ignore_index in [a, b] or a == b
1913
1914    def calculate(self) -> float:
1915        """
1916        Calculate the metric (at the end of an epoch)
1917
1918        Returns
1919        -------
1920        result : float
1921            the metric value
1922        """
1923
1924        return self.edit_distance / (self.total + 1e-7)

Edit distance (not advised for training)

Normalized by the length of the sequences

EditDistance(ignore_index: int = -100, exclusive: bool = True)
1842    def __init__(self, ignore_index: int = -100, exclusive: bool = True) -> None:
1843        """
1844        Parameters
1845        ----------
1846        ignore_index : int, default -100
1847            the class index that indicates samples that should be ignored
1848        exclusive : bool, default True
1849            set to False for multi-label classification tasks
1850        """
1851
1852        super().__init__()
1853        self.ignore_index = ignore_index
1854        self.exclusive = exclusive

Parameters

ignore_index : int, default -100
    the class index that indicates samples that should be ignored
exclusive : bool, default True
    set to False for multi-label classification tasks

def reset(self) -> None:
1856    def reset(self) -> None:
1857        """
1858        Reset the intrinsic parameters (at the beginning of an epoch)
1859        """
1860
1861        self.edit_distance = 0
1862        self.total = 0

Reset the intrinsic parameters (at the beginning of an epoch)

def update( self, predicted: torch.Tensor, target: torch.Tensor, tags: torch.Tensor) -> None:
1864    def update(
1865        self,
1866        predicted: torch.Tensor,
1867        target: torch.Tensor,
1868        tags: torch.Tensor,
1869    ) -> None:
1870        """
1871        Update the intrinsic parameters (with a batch)
1872
1873        Parameters
1874        ----------
1875        predicted : torch.Tensor
1876            the main prediction tensor generated by the model
1879        target : torch.Tensor
1880            the corresponding main target tensor
1883        tags : torch.Tensor
1884            the tensor of meta tags (or `None`, if tags are not given)
1885        """
1886
1887        mask = target != self.ignore_index
1888        self.total += torch.sum(mask)
1889        if self.exclusive:
1890            predicted = predicted[mask].flatten()
1891            target = target[mask].flatten()
1892            self.edit_distance += editdistance.eval(
1893                predicted.detach().cpu().numpy(), target.detach().cpu().numpy()
1894            )
1895        else:
1896            for c in range(target.shape[1]):
1897                predicted_class = predicted[:, c, :][mask[:, c, :]].flatten()
1898                target_class = target[:, c, :][mask[:, c, :]].flatten()
1899                self.edit_distance += editdistance.eval(
1900                    predicted_class.detach().cpu().tolist(),
1901                    target_class.detach().cpu().tolist(),
1902                )

Update the intrinsic parameters (with a batch)

Parameters

predicted : torch.Tensor
    the main prediction tensor generated by the model
target : torch.Tensor
    the corresponding main target tensor
tags : torch.Tensor
    the tensor of meta tags (or `None`, if tags are not given)

def calculate(self) -> float:
1914    def calculate(self) -> float:
1915        """
1916        Calculate the metric (at the end of an epoch)
1917
1918        Returns
1919        -------
1920        result : float
1921            the metric value
1922        """
1923
1924        return self.edit_distance / (self.total + 1e-7)

Calculate the metric (at the end of an epoch)

Returns

result : float
    the metric value
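
A minimal usage sketch with hypothetical exclusive predictions; the result is the Levenshtein distance divided by the number of counted frames:

import torch

ed = EditDistance()
ed.reset()

predicted = torch.tensor([0, 0, 1, 1, 2])
target = torch.tensor([0, 0, 2, 1, 2])

ed.update(predicted, target, tags=None)
print(ed.calculate())  # one substitution over five frames, ~0.2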

class PKU_mAP(dlc2action.metric.base_metric.Metric):
1927class PKU_mAP(Metric):
1928    """
1929    Mean average precision (segmental) (not advised for training)
1930    """
1931
1932    needs_raw_data = True
1933
1934    def __init__(
1935        self,
1936        average,
1937        exclusive,
1938        num_classes,
1939        iou_threshold=0.5,
1940        threshold_value=0.5,
1941        ignored_classes=None,
1942    ):
1943        if ignored_classes is None:
1944            ignored_classes = []
1945        self.average = average
1946        self.iou_threshold = iou_threshold
1947        self.threshold = threshold_value
1948        self.exclusive = exclusive
1949        self.classes = [x for x in list(range(num_classes)) if x not in ignored_classes]
1950        super().__init__()
1951
1952    def match(self, lst, ratio, ground):
1953        lst = sorted(lst, key=lambda x: x[2])
1954
1955        def overlap(prop, ground):
1956            s_p, e_p, _ = prop
1957            s_g, e_g, _ = ground
1958            return (min(e_p, e_g) - max(s_p, s_g)) / (max(e_p, e_g) - min(s_p, s_g))
1959
1960        cos_map = [-1 for x in range(len(lst))]
1961        count_map = [0 for x in range(len(ground))]
1962
1963        for x in range(len(lst)):
1964            for y in range(len(ground)):
1965                if overlap(lst[x], ground[y]) < ratio:
1966                    continue
1967                if cos_map[x] != -1 and overlap(lst[x], ground[y]) < overlap(lst[x], ground[cos_map[x]]):
1968                    continue
1969                cos_map[x] = y
1970            if cos_map[x] != -1:
1971                count_map[cos_map[x]] += 1
1972        positive = sum([(x > 0) for x in count_map])
1973        return cos_map, count_map, positive, [x[2] for x in lst]
1974
1975    def reset(self) -> None:
1976        self.count_map = defaultdict(lambda: [])
1977        self.positive = defaultdict(lambda: 0)
1978        self.cos_map = defaultdict(lambda: [])
1979        self.confidence = defaultdict(lambda: [])
1980
1981    def calc_pr(self, positive, proposal, ground):
1982        if proposal == 0:
1983            return 0, 0
1984        if ground == 0:
1985            return 0, 0
1986        return (1.0 * positive) / proposal, (1.0 * positive) / ground
1987
1988    def calculate(self) -> Union[float, Dict]:
1989        if self.average == "micro":
1990            confidence = []
1991            count_map = []
1992            cos_map = []
1993            positive = sum(self.positive.values())
1994            for key in self.count_map.keys():
1995                confidence += self.confidence[key]
1996                cos_map += list(np.array(self.cos_map[key]) + len(count_map))
1997                count_map += self.count_map[key]
1998            return self.ap(cos_map, count_map, positive, confidence)
1999        results = {
2000            key: self.ap(
2001                self.cos_map[key],
2002                self.count_map[key],
2003                self.positive[key],
2004                self.confidence[key],
2005            )
2006            for key in self.count_map.keys()
2007        }
2008        if self.average == "none":
2009            return results
2010        else:
2011            return float(np.mean(list(results.values())))
2012
2013    def ap(self, cos_map, count_map, positive, confidence):
2014        indices = np.argsort(confidence)
2015        cos_map = list(np.array(cos_map)[indices])
2016        score = 0
2017        number_proposal = len(cos_map)
2018        number_ground = len(count_map)
2019        old_precision, old_recall = self.calc_pr(
2020            positive, number_proposal, number_ground
2021        )
2022
2023        for x in range(len(cos_map)):
2024            number_proposal -= 1
2025            if cos_map[x] == -1:
2026                continue
2027            count_map[cos_map[x]] -= 1
2028            if count_map[cos_map[x]] == 0:
2029                positive -= 1
2030
2031            precision, recall = self.calc_pr(positive, number_proposal, number_ground)
2032            if precision > old_precision:
2033                old_precision = precision
2034            score += old_precision * (old_recall - recall)
2035            old_recall = recall
2036        return score
2037
2038    def _get_intervals(
2039        self, tensor: torch.Tensor, probability: torch.Tensor = None
2040    ) -> Union[Tuple, torch.Tensor]:
2041        """
2042        Get True group beginning and end indices from a boolean tensor and average probability over these intervals
2043        """
2044
2045        output, indices = torch.unique_consecutive(tensor, return_inverse=True)
2046        true_indices = torch.where(output)[0]
2047        starts = torch.tensor(
2048            [(indices == i).nonzero(as_tuple=True)[0][0] for i in true_indices]
2049        )
2050        ends = torch.tensor(
2051            [(indices == i).nonzero(as_tuple=True)[0][-1] + 1 for i in true_indices]
2052        )
2053        confidence = torch.tensor(
2054            [probability[indices == i].mean() for i in true_indices]
2055        )
2056        return torch.stack([starts, ends, confidence]).T
2057
2058    def update(
2059        self,
2060        predicted: torch.Tensor,
2061        target: torch.Tensor,
2062        tags: torch.Tensor,
2063    ) -> None:
2064        predicted = torch.cat(
2065            [
2066                copy(predicted),
2067                -100 * torch.ones((*predicted.shape[:-1], 1)).to(predicted.device),
2068            ],
2069            dim=-1,
2070        )
2071        target = torch.cat(
2072            [
2073                copy(target),
2074                -100 * torch.ones((*target.shape[:-1], 1)).to(target.device),
2075            ],
2076            dim=-1,
2077        )
2078        num_classes = predicted.shape[1]
2079        predicted = predicted.transpose(1, 2).reshape(-1, num_classes)
2080        if self.exclusive:
2081            target = target.flatten()
2082        else:
2083            target = target.transpose(1, 2).reshape(-1, num_classes)
2084        probability = copy(predicted)
2085        if not self.exclusive:
2086            predicted = (predicted > self.threshold).int()
2087        else:
2088            predicted = torch.max(predicted, 1)[1]
2089        for c in self.classes:
2090            if self.exclusive:
2091                predicted_intervals = self._get_intervals(
2092                    predicted == c, probability=probability[:, c]
2093                )
2094                target_intervals = self._get_intervals(
2095                    target == c, probability=probability[:, c]
2096                )
2097            else:
2098                predicted_intervals = self._get_intervals(
2099                    predicted[:, c] == 1, probability=probability[:, c]
2100                )
2101                target_intervals = self._get_intervals(
2102                    target[:, c] == 1, probability=probability[:, c]
2103                )
2104            cos_map, count_map, positive, confidence = self.match(
2105                predicted_intervals, self.iou_threshold, target_intervals
2106            )
2107            cos_map = np.array(cos_map)
2108            cos_map[cos_map != -1] += len(self.count_map[c])
2109            self.cos_map[c] += list(cos_map)
2110            self.count_map[c] += count_map
2111            self.confidence[c] += confidence
2112            self.positive[c] += positive

Mean average precision (segmental) (not advised for training)

PKU_mAP( average, exclusive, num_classes, iou_threshold=0.5, threshold_value=0.5, ignored_classes=None)
1934    def __init__(
1935        self,
1936        average,
1937        exclusive,
1938        num_classes,
1939        iou_threshold=0.5,
1940        threshold_value=0.5,
1941        ignored_classes=None,
1942    ):
1943        if ignored_classes is None:
1944            ignored_classes = []
1945        self.average = average
1946        self.iou_threshold = iou_threshold
1947        self.threshold = threshold_value
1948        self.exclusive = exclusive
1949        self.classes = [x for x in list(range(num_classes)) if x not in ignored_classes]
1950        super().__init__()
needs_raw_data = True

If `True`, `dlc2action.task.universal_task.Task` will pass raw model output to the metric (with only the primary predict function applied); otherwise it will pass the final class predictions.

def match(self, lst, ratio, ground)
1952    def match(self, lst, ratio, ground):
1953        lst = sorted(lst, key=lambda x: x[2])
1954
1955        def overlap(prop, ground):
1956            s_p, e_p, _ = prop
1957            s_g, e_g, _ = ground
1958            return (min(e_p, e_g) - max(s_p, s_g)) / (max(e_p, e_g) - min(s_p, s_g))
1959
1960        cos_map = [-1 for x in range(len(lst))]
1961        count_map = [0 for x in range(len(ground))]
1962
1963        for x in range(len(lst)):
1964            for y in range(len(ground)):
1965                if overlap(lst[x], ground[y]) < ratio:
1966                    continue
1967                if cos_map[x] != -1 and overlap(lst[x], ground[y]) < overlap(lst[x], ground[cos_map[x]]):
1968                    continue
1969                cos_map[x] = y
1970            if cos_map[x] != -1:
1971                count_map[cos_map[x]] += 1
1972        positive = sum([(x > 0) for x in count_map])
1973        return cos_map, count_map, positive, [x[2] for x in lst]
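
Match predicted intervals of one class against the ground truth intervals: each proposal (a (start, end, confidence) triple) is assigned to the ground truth interval it overlaps best, provided the IoU reaches ratio. The method returns the assignment map, the per-ground-truth match counts, the number of matched ground truth intervals, and the proposal confidences.

The overlap helper is a one-dimensional IoU; a standalone check with hypothetical intervals:

def overlap(prop, ground):
    # intersection length over union length of two 1-D intervals
    s_p, e_p, _ = prop
    s_g, e_g, _ = ground
    return (min(e_p, e_g) - max(s_p, s_g)) / (max(e_p, e_g) - min(s_p, s_g))

# frames 10-30 vs. frames 20-40: intersection 10, union 30, IoU 1/3
print(overlap((10, 30, 0.9), (20, 40, 0.0)))
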
def reset(self) -> None:
1975    def reset(self) -> None:
1976        self.count_map = defaultdict(lambda: [])
1977        self.positive = defaultdict(lambda: 0)
1978        self.cos_map = defaultdict(lambda: [])
1979        self.confidence = defaultdict(lambda: [])

Reset the intrinsic parameters (at the beginning of an epoch)

def calc_pr(self, positive, proposal, ground)
1981    def calc_pr(self, positive, proposal, ground):
1982        if proposal == 0:
1983            return 0, 0
1984        if ground == 0:
1985            return 0, 0
1986        return (1.0 * positive) / proposal, (1.0 * positive) / ground
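
Compute the precision (positive / proposal) and recall (positive / ground) pair, returning (0, 0) when there are no proposals or no ground truth intervals.
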
def calculate(self) -> Union[float, Dict]:
1988    def calculate(self) -> Union[float, Dict]:
1989        if self.average == "micro":
1990            confidence = []
1991            count_map = []
1992            cos_map = []
1993            positive = sum(self.positive.values())
1994            for key in self.count_map.keys():
1995                confidence += self.confidence[key]
1996                cos_map += list(np.array(self.cos_map[key]) + len(count_map))
1997                count_map += self.count_map[key]
1998            return self.ap(cos_map, count_map, positive, confidence)
1999        results = {
2000            key: self.ap(
2001                self.cos_map[key],
2002                self.count_map[key],
2003                self.positive[key],
2004                self.confidence[key],
2005            )
2006            for key in self.count_map.keys()
2007        }
2008        if self.average == "none":
2009            return results
2010        else:
2011            return float(np.mean(list(results.values())))

Calculate the metric (at the end of an epoch)

Returns

result : float | dict
    either the single value of the metric or a dictionary where the keys are class indices and the values are class metric values

def ap(self, cos_map, count_map, positive, confidence)
2013    def ap(self, cos_map, count_map, positive, confidence):
2014        indices = np.argsort(confidence)
2015        cos_map = list(np.array(cos_map)[indices])
2016        score = 0
2017        number_proposal = len(cos_map)
2018        number_ground = len(count_map)
2019        old_precision, old_recall = self.calc_pr(
2020            positive, number_proposal, number_ground
2021        )
2022
2023        for x in range(len(cos_map)):
2024            number_proposal -= 1
2025            if cos_map[x] == -1:
2026                continue
2027            count_map[cos_map[x]] -= 1
2028            if count_map[cos_map[x]] == 0:
2029                positive -= 1
2030
2031            precision, recall = self.calc_pr(positive, number_proposal, number_ground)
2032            if precision > old_precision:
2033                old_precision = precision
2034            score += old_precision * (old_recall - recall)
2035            old_recall = recall
2036        return score
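
Compute the average precision for one class: proposals are removed one at a time in order of increasing confidence, precision and recall are recomputed after each removal, and the area is accumulated as precision times the recall decrement, keeping the best precision seen so far as the interpolated value.
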
def update( self, predicted: torch.Tensor, target: torch.Tensor, tags: torch.Tensor) -> None:
2058    def update(
2059        self,
2060        predicted: torch.Tensor,
2061        target: torch.Tensor,
2062        tags: torch.Tensor,
2063    ) -> None:
2064        predicted = torch.cat(
2065            [
2066                copy(predicted),
2067                -100 * torch.ones((*predicted.shape[:-1], 1)).to(predicted.device),
2068            ],
2069            dim=-1,
2070        )
2071        target = torch.cat(
2072            [
2073                copy(target),
2074                -100 * torch.ones((*target.shape[:-1], 1)).to(target.device),
2075            ],
2076            dim=-1,
2077        )
2078        num_classes = predicted.shape[1]
2079        predicted = predicted.transpose(1, 2).reshape(-1, num_classes)
2080        if self.exclusive:
2081            target = target.flatten()
2082        else:
2083            target = target.transpose(1, 2).reshape(-1, num_classes)
2084        probability = copy(predicted)
2085        if not self.exclusive:
2086            predicted = (predicted > self.threshold).int()
2087        else:
2088            predicted = torch.max(predicted, 1)[1]
2089        for c in self.classes:
2090            if self.exclusive:
2091                predicted_intervals = self._get_intervals(
2092                    predicted == c, probability=probability[:, c]
2093                )
2094                target_intervals = self._get_intervals(
2095                    target == c, probability=probability[:, c]
2096                )
2097            else:
2098                predicted_intervals = self._get_intervals(
2099                    predicted[:, c] == 1, probability=probability[:, c]
2100                )
2101                target_intervals = self._get_intervals(
2102                    target[:, c] == 1, probability=probability[:, c]
2103                )
2104            cos_map, count_map, positive, confidence = self.match(
2105                predicted_intervals, self.iou_threshold, target_intervals
2106            )
2107            cos_map = np.array(cos_map)
2108            cos_map[cos_map != -1] += len(self.count_map[c])
2109            self.cos_map[c] += list(cos_map)
2110            self.count_map[c] += count_map
2111            self.confidence[c] += confidence
2112            self.positive[c] += positive

Update the intrinsic parameters (with a batch)

Parameters

predicted : torch.Tensor
    the main prediction tensor generated by the model
target : torch.Tensor
    the corresponding main target tensor
tags : torch.Tensor
    the tensor of meta tags (or `None`, if tags are not given)
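
A hedged usage sketch (shapes and values are assumptions for illustration; raw probabilities are expected, since needs_raw_data is True):

import torch

pku_map = PKU_mAP(average="macro", exclusive=False, num_classes=2)
pku_map.reset()

# hypothetical raw probabilities and binary targets, shape (batch, classes, frames)
predicted = torch.rand(2, 2, 50)
target = (torch.rand(2, 2, 50) > 0.7).int()

pku_map.update(predicted, target, tags=None)
print(pku_map.calculate())  # mean segmental AP over the two classes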