dlc2action.metric.metrics
Implementations of dlc2action.metric.base_metric.Metric.
#
# Copyright 2020-present by A. Mathis Group and contributors. All rights reserved.
#
# This project and all its files are licensed under GNU AGPLv3 or later version.
# A copy is included in dlc2action/LICENSE.AGPL.
#
"""Implementations of `dlc2action.metric.base_metric.Metric`."""

import warnings
from abc import abstractmethod
from collections import defaultdict
from copy import copy, deepcopy
from typing import Any, Dict, List, Set, Tuple, Union

import editdistance
import numpy as np
import torch
from dlc2action.metric.base_metric import Metric
from sklearn import metrics


class _ClassificationMetric(Metric):
    """The base class for all metrics that are calculated from true and false positive and negative counts."""

    needs_raw_data = True
    segmental = False
    """
    If `True`, the metric will be calculated over segments; otherwise over frames.
    """

    def __init__(
        self,
        num_classes: int,
        ignore_index: int = -100,
        average: str = "macro",
        ignored_classes: Set = None,
        exclusive: bool = True,
        iou_threshold: float = 0.5,
        tag_average: str = "micro",
        threshold_values: List = None,
        integration_interval: int = 0,
    ):
        """Initialize the metric.

        Parameters
        ----------
        num_classes : int
            the number of classes
        ignore_index : int, default -100
            the class index that indicates ignored samples
        average : {'macro', 'micro', 'none'}
            the method for averaging across classes
        ignored_classes : set, optional
            a set of class ids to ignore in the calculation
        exclusive : bool, default True
            set to `False` for multi-label classification tasks
        iou_threshold : float, default 0.5
            if `segmental` is `True`, intervals with an IoU larger than this threshold are considered correct
        tag_average : {'micro', 'macro', 'none'}
            the method for averaging across meta tags (if given)
        threshold_values : list, optional
            a list of float values between 0 and 1 that will be used as decision thresholds
        integration_interval : int, default 0
            if not 0, the predictions are averaged over a window of this size (in frames) around each frame

        """
        super().__init__()
        self.ignore_index = ignore_index
        self.average = average
        self.tag_average = tag_average
        self.tags = set()
        self.exclusive = exclusive
        self.threshold = iou_threshold
        self.integration_interval = integration_interval
        self.classes = list(range(int(num_classes)))
        self.optimize = False
        if threshold_values is None:
            threshold_values = [0.5]
        elif threshold_values == ["optimize"]:
            threshold_values = list(np.arange(0.25, 0.75, 0.05))
            self.optimize = True
        if self.exclusive and threshold_values != [0.5]:
            raise ValueError(
                "Cannot set threshold values for exclusive classification!"
            )
        self.threshold_values = threshold_values

        if ignored_classes is not None:
            for c in ignored_classes:
                if c in self.classes:
                    self.classes.remove(c)
        if len(self.classes) == 0:
            warnings.warn("No classes are followed!")

    def reset(self) -> None:
        """Reset the internal state (at the beginning of an epoch)."""
        self.tags = set()
        self.tp = defaultdict(lambda: defaultdict(lambda: 0))
        self.fp = defaultdict(lambda: defaultdict(lambda: 0))
        self.fn = defaultdict(lambda: defaultdict(lambda: 0))
        self.tn = defaultdict(lambda: defaultdict(lambda: 0))

    def update(
        self,
        predicted: torch.Tensor,
        target: torch.Tensor,
        tags: torch.Tensor,
    ) -> None:
        """Update the internal state (with a batch).

        Parameters
        ----------
        predicted : torch.Tensor
            the main prediction tensor generated by the model
        target : torch.Tensor
            the corresponding main target tensor
        tags : torch.Tensor
            the tensor of meta tags (or `None`, if tags are not given)

        """
        if self.segmental:
            self._update_segmental(predicted, target, tags)
        else:
            self._update_normal(predicted, target, tags)

    def _key(self, tag: torch.Tensor, c: int) -> str:
        """Get a key for the intermediate value dictionaries from tag and class indices."""
        if tag is None or self.tag_average == "micro":
            return c
        else:
            return f"tag{int(tag)}_{c}"

    def _update_normal(
        self,
        predicted: torch.Tensor,
        target: torch.Tensor,
        tags: torch.Tensor,
    ) -> None:
        """Update the internal state (with a batch), calculating over frames."""
        if self.integration_interval != 0:
            pr = 0
            denom = torch.ones((1, 1, predicted.shape[-1])) * (
                2 * self.integration_interval + 1
            )
            for i in range(self.integration_interval):
                pr += torch.cat(
                    [
                        torch.zeros((predicted.shape[0], predicted.shape[1], i + 1)),
                        predicted[:, :, i + 1 :],
                    ],
                    dim=-1,
                )
                pr += torch.cat(
                    [
                        predicted[:, :, : -i - 1],
                        torch.zeros((predicted.shape[0], predicted.shape[1], i + 1)),
                    ],
                    dim=-1,
                )
                denom[:, :, i] = i + 1
                denom[:, :, -i] = i + 1
            predicted = pr / denom
        if self.exclusive:
            predicted = torch.max(predicted, 1)[1]
        if self.tag_average == "micro" or tags is None or tags[0] is None:
            tag_set = {None}
        else:
            tag_set = set(tags)
        self.tags.update(tag_set)
        for thr in self.threshold_values:
            for t in tag_set:
                if t is None:
                    predicted_t = predicted
                    target_t = target
                else:
                    predicted_t = predicted[tags == t]
                    target_t = target[tags == t]
                for c in self.classes:
                    if self.exclusive:
                        pos = (predicted_t == c) * (target_t != self.ignore_index)
                        neg = (predicted_t != c) * (target_t != self.ignore_index)
                        self.tp[self._key(t, c)][thr] += ((target_t == c) * pos).sum()
                        self.fp[self._key(t, c)][thr] += ((target_t != c) * pos).sum()
                        self.fn[self._key(t, c)][thr] += ((target_t == c) * neg).sum()
                        self.tn[self._key(t, c)][thr] += ((target_t != c) * neg).sum()
                    else:
                        if isinstance(thr, list):
                            key = ", ".join(map(str, thr))
                            for i, tt in enumerate(thr):
                                predicted_t[:, i, :] = (predicted_t[:, i, :] > tt).int()
                        else:
                            key = thr
                            predicted_t = (predicted_t > thr).int()
                        pos = predicted_t[:, c, :] * (
                            target_t[:, c, :] != self.ignore_index
                        )
                        neg = (predicted_t[:, c, :] != 1) * (
                            target_t[:, c, :] != self.ignore_index
                        )
                        self.tp[self._key(t, c)][key] += (target_t[:, c, :] * pos).sum()
                        self.fp[self._key(t, c)][key] += (
                            (target_t[:, c, :] != 1) * pos
                        ).sum()
                        self.fn[self._key(t, c)][key] += (target_t[:, c, :] * neg).sum()
                        self.tn[self._key(t, c)][key] += (
                            (target_t[:, c, :] != 1) * neg
                        ).sum()

    def _get_intervals(self, tensor: torch.Tensor) -> torch.Tensor:
        """Get a list of `True` group beginning and end indices from a boolean tensor."""
        output, indices = torch.unique_consecutive(tensor, return_inverse=True)
        true_indices = torch.where(output)[0]
        starts = torch.tensor(
            [(indices == i).nonzero(as_tuple=True)[0][0] for i in true_indices]
        )
        ends = torch.tensor(
            [(indices == i).nonzero(as_tuple=True)[0][-1] + 1 for i in true_indices]
        )
        return torch.stack([starts, ends]).T

    def _smooth(self, tensor: torch.Tensor, smooth_interval: int = 1) -> torch.Tensor:
        """Get rid of jittering in a non-exclusive classification tensor.

        First, remove intervals of 0 shorter than `smooth_interval`. Then, remove intervals of 1 shorter than
        `smooth_interval`.

        """
        for c in self.classes:
            intervals = self._get_intervals(tensor[:, c] == 0)
            interval_lengths = torch.tensor(
                [interval[1] - interval[0] for interval in intervals]
            )
            short_intervals = intervals[interval_lengths <= smooth_interval]
            for start, end in short_intervals:
                tensor[start:end, c] = 1
            intervals = self._get_intervals(tensor[:, c] == 1)
            interval_lengths = torch.tensor(
                [interval[1] - interval[0] for interval in intervals]
            )
            short_intervals = intervals[interval_lengths <= smooth_interval]
            for start, end in short_intervals:
                tensor[start:end, c] = 0
        return tensor

    def _sigmoid_threshold_function(
        self,
        low_threshold: float,
        high_threshold: float,
        low_length: int,
        high_length: int,
    ):
        """Generate a sigmoid threshold function.

        The resulting function outputs an intersection threshold given the length of the interval.

        """
        a = 2 / (high_length - low_length)
        b = 1 - a * high_length
        return lambda x: low_threshold + torch.sigmoid(4 * (a * x + b)) * (
            high_threshold - low_threshold
        )

    def _update_segmental(
        self,
        predicted: torch.Tensor,
        target: torch.Tensor,
        tags: torch.Tensor,
    ) -> None:
        """Update the internal state (with a batch), calculating over segments."""
        if self.exclusive:
            predicted = torch.max(predicted, 1)[1]
        predicted = torch.cat(
            [
                copy(predicted),
                -100 * torch.ones((*predicted.shape[:-1], 1)).to(predicted.device),
            ],
            dim=-1,
        )
        target = torch.cat(
            [
                copy(target),
                -100 * torch.ones((*target.shape[:-1], 1)).to(target.device),
            ],
            dim=-1,
        )
        if self.exclusive:
            predicted = predicted.flatten()
            target = target.flatten()
        else:
            num_classes = predicted.shape[1]
            predicted = predicted.transpose(1, 2).reshape(-1, num_classes)
            target = target.transpose(1, 2).reshape(-1, num_classes)
        if self.tag_average == "micro" or tags is None or tags[0] is None:
            tag_set = {None}
        else:
            tag_set = set(tags)
        self.tags.update(tag_set)
        for thr in self.threshold_values:
            key = thr
            for t in tag_set:
                if t is None:
                    predicted_t = predicted
                    target_t = target
                else:
                    predicted_t = predicted[tags == t]
                    target_t = target[tags == t]
                if not self.exclusive:
                    if isinstance(thr, list):
                        # apply the per-class decision thresholds column-wise
                        # (the tensor is (frames, classes) after the reshape above)
                        for i, tt in enumerate(thr):
                            predicted_t[:, i] = (predicted_t[:, i] > tt).int()
                        key = ", ".join(map(str, thr))
                    else:
                        predicted_t = (predicted_t > thr).int()
                for c in self.classes:
                    if self.exclusive:
                        predicted_intervals = self._get_intervals(predicted_t == c)
                        target_intervals = self._get_intervals(target_t == c)
                    else:
                        predicted_intervals = self._get_intervals(
                            predicted_t[:, c] == 1
                        )
                        target_intervals = self._get_intervals(target_t[:, c] == 1)
                    true_used = torch.zeros(target_intervals.shape[0])
                    for interval in predicted_intervals:
                        if len(target_intervals) > 0:
                            # Compute IoU against all ground truth intervals
                            intersection = torch.minimum(
                                interval[1], target_intervals[:, 1]
                            ) - torch.maximum(interval[0], target_intervals[:, 0])
                            union = torch.maximum(
                                interval[1], target_intervals[:, 1]
                            ) - torch.minimum(interval[0], target_intervals[:, 0])
                            IoU = intersection / union

                            # Get the best scoring segment
                            idx = IoU.argmax()

                            # If the IoU is high enough and the true segment isn't already used,
                            # it is a true positive. Otherwise it is a false positive.
                            if IoU[idx] >= self.threshold and not true_used[idx]:
                                self.tp[self._key(t, c)][key] += 1
                                true_used[idx] = 1
                            else:
                                self.fp[self._key(t, c)][key] += 1
                        else:
                            self.fp[self._key(t, c)][key] += 1
                    self.fn[self._key(t, c)][key] += len(true_used) - torch.sum(
                        true_used
                    )

    def calculate(self) -> Union[float, Dict]:
        """Calculate the metric (at the end of an epoch).

        Returns
        -------
        result : float | dict
            either the single value of the metric or a dictionary where the keys are class indices and the
            values are class metric values

        """
        if self.tag_average == "micro" or self.tags == {None}:
            self.tags = {None}
        result = {}
        self.tags = sorted(list(self.tags))
        for tag in self.tags:
            if self.average == "macro":
                metric_list = []
                for c in self.classes:
                    metric_list.append(
                        self._calculate_metric(
                            self.tp[self._key(tag, c)],
                            self.fp[self._key(tag, c)],
                            self.fn[self._key(tag, c)],
                            self.tn[self._key(tag, c)],
                        )
                    )
                if tag is not None:
                    tag = int(tag)
                result[f"tag{tag}"] = sum(metric_list) / len(metric_list)
            elif self.average == "micro":
                # tp = sum([self.tp[self._key(tag, c)] for c in self.classes])
                # fn = sum([self.fn[self._key(tag, c)] for c in self.classes])
                # fp = sum([self.fp[self._key(tag, c)] for c in self.classes])
                # tn = sum([self.tn[self._key(tag, c)] for c in self.classes])
                # if tag is not None:
                #     tag = int(tag)
                # result[f"tag{tag}"] = self._calculate_metric(tp, fp, fn, tn)
                metric_list = []
                for c in self.classes:
                    metric_list.append(
                        self._calculate_metric(
                            self.tp[self._key(tag, c)],
                            self.fp[self._key(tag, c)],
                            self.fn[self._key(tag, c)],
                            self.tn[self._key(tag, c)],
                        )
                    )
                if tag is not None:
                    tag = int(tag)
                result[f"tag{tag}"] = metric_list
            elif self.average == "none":
                metric_dict = {}
                for c in self.classes:
                    metric_dict[self._key(tag, c)] = self._calculate_metric(
                        self.tp[self._key(tag, c)],
                        self.fp[self._key(tag, c)],
                        self.fn[self._key(tag, c)],
                        self.tn[self._key(tag, c)],
                    )
                result.update(metric_dict)
            else:
                raise ValueError(
                    f"The {self.average} averaging method is not available, please choose from "
                    '["none", "micro", "macro"]'
                )
        if len(self.tags) == 1 and self.average != "none":
            tag = self.tags[0]
            if tag is not None:
                tag = int(tag)
            result = result[f"tag{tag}"]
        elif self.tag_average == "macro":
            if self.average == "none":
                r = {}
                for c in self.classes:
                    r[c] = torch.mean(
                        torch.tensor([result[self._key(tag, c)] for tag in self.tags])
                    )
            else:
                r = torch.mean(
                    torch.tensor([result[f"tag{int(tag)}"] for tag in self.tags])
                )
            result = r
        return result

    @abstractmethod
    def _calculate_metric(self, tp: Dict, fp: Dict, fn: Dict, tn: Dict) -> float:
        """Calculate the metric value from true and false positive and negative counts.

        Parameters
        ----------
        tp : dict
            true positive counts (per decision threshold key)
        fp : dict
            false positive counts (per decision threshold key)
        fn : dict
            false negative counts (per decision threshold key)
        tn : dict
            true negative counts (per decision threshold key)

        Returns
        -------
        metric : float
            the metric value

        """
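# A hedged sketch of what `_ClassificationMetric._get_intervals` computes (the
# values below are illustrative, not from the original file): it returns the
# half-open [start, end) boundaries of every consecutive `True` run in a
# boolean tensor.
#
#     mask = torch.tensor([False, True, True, False, True])
#     # metric._get_intervals(mask) -> tensor([[1, 3],
#     #                                        [4, 5]])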

class PR_AUC(_ClassificationMetric):
    """Area under the precision-recall curve (not advised for training)."""

    def __init__(
        self,
        num_classes: int,
        ignore_index: int = -100,
        average: str = "macro",
        ignored_classes: Set = None,
        exclusive: bool = False,
        tag_average: str = "micro",
        threshold_step: float = 0.1,
    ):
        """Initialize the metric.

        Parameters
        ----------
        num_classes : int
            the number of classes
        ignore_index : int, default -100
            the class index that indicates ignored samples
        average : {'macro', 'micro', 'none'}
            the method for averaging across classes
        ignored_classes : set, optional
            a set of class ids to ignore in the calculation
        exclusive : bool, default False
            set to `False` for multi-label classification tasks
        tag_average : {'micro', 'macro', 'none'}
            the method for averaging across meta tags (if given)
        threshold_step : float, default 0.1
            the decision threshold step

        """
        if exclusive:
            raise ValueError(
                "The PR-AUC metric is not implemented for exclusive classification!"
            )
        super().__init__(
            num_classes,
            ignore_index,
            average,
            ignored_classes,
            exclusive,
            tag_average=tag_average,
            threshold_values=list(np.arange(0, 1, threshold_step)),
        )

    def _calculate_metric(self, tp: Dict, fp: Dict, fn: Dict, tn: Dict) -> float:
        """Calculate the metric value from true and false positive and negative counts."""
        precisions = []
        recalls = []
        for k in sorted(self.threshold_values):
            precisions.append(tp[k] / (tp[k] + fp[k] + 1e-7))
            recalls.append(tp[k] / (tp[k] + fn[k] + 1e-7))
        return metrics.auc(x=recalls, y=precisions)


class Precision(_ClassificationMetric):
    """Precision."""

    def __init__(
        self,
        num_classes: int,
        ignore_index: int = -100,
        average: str = "macro",
        ignored_classes: Set = None,
        exclusive: bool = True,
        tag_average: str = "micro",
        threshold_value: Union[float, List] = None,
    ):
        """Initialize the class.

        Parameters
        ----------
        num_classes : int
            the number of classes
        ignore_index : int, default -100
            the class index that indicates ignored samples
        average : {'macro', 'micro', 'none'}
            the method for averaging across classes
        ignored_classes : set, optional
            a set of class ids to ignore in the calculation
        exclusive : bool, default True
            set to `False` for multi-label classification tasks
        tag_average : {'micro', 'macro', 'none'}
            the method for averaging across meta tags (if given)
        threshold_value : float | list, optional
            the decision threshold value (cannot be set for exclusive classification; defaults to 0.5
            for non-exclusive); if `threshold_value` is a list, every value should correspond to the
            class under the same index

        """
        if threshold_value is None:
            threshold_value = 0.5
        super().__init__(
            num_classes,
            ignore_index,
            average,
            ignored_classes,
            exclusive,
            tag_average=tag_average,
            threshold_values=[threshold_value],
        )

    def _calculate_metric(self, tp: Dict, fp: Dict, fn: Dict, tn: Dict) -> float:
        """Calculate the metric value from true and false positive and negative counts."""
        k = self.threshold_values[0]
        if isinstance(k, list):
            k = ", ".join(map(str, k))
        return tp[k] / (tp[k] + fp[k] + 1e-7)
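# A minimal usage sketch for the frame-wise metrics above, assuming raw scores
# of shape (batch, classes, frames) and binary targets of the same shape for a
# non-exclusive task (the tensors here are random placeholders):
#
#     metric = Precision(num_classes=3, exclusive=False, average="macro")
#     metric.reset()
#     predicted = torch.rand(2, 3, 100)             # raw per-class scores
#     target = (torch.rand(2, 3, 100) > 0.5).int()  # binary frame-wise labels
#     metric.update(predicted, target, tags=None)
#     value = metric.calculate()                    # macro-averaged precision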

class SegmentalPrecision(_ClassificationMetric):
    """Segmental precision (not advised for training)."""

    segmental = True

    def __init__(
        self,
        num_classes: int,
        ignore_index: int = -100,
        average: str = "macro",
        ignored_classes: Set = None,
        exclusive: bool = True,
        iou_threshold: float = 0.5,
        tag_average: str = "micro",
        threshold_value: Union[float, List] = None,
    ):
        """Initialize the class.

        Parameters
        ----------
        num_classes : int
            the number of classes
        ignore_index : int, default -100
            the class index that indicates ignored samples
        average : {'macro', 'micro', 'none'}
            the method for averaging across classes
        ignored_classes : set, optional
            a set of class ids to ignore in the calculation
        exclusive : bool, default True
            set to `False` for multi-label classification tasks
        iou_threshold : float, default 0.5
            intervals with an IoU larger than this threshold are considered correct
        tag_average : {'micro', 'macro', 'none'}
            the method for averaging across meta tags (if given)
        threshold_value : float | list, optional
            the decision threshold value (cannot be set for exclusive classification; defaults to 0.5
            for non-exclusive); if `threshold_value` is a list, every value should correspond to the
            class under the same index

        """
        if threshold_value is None:
            threshold_value = 0.5
        super().__init__(
            num_classes,
            ignore_index,
            average,
            ignored_classes,
            exclusive,
            iou_threshold,
            tag_average,
            [threshold_value],
        )

    def _calculate_metric(self, tp: Dict, fp: Dict, fn: Dict, tn: Dict) -> float:
        """Calculate the metric value from true and false positive and negative counts."""
        k = self.threshold_values[0]
        if isinstance(k, list):
            k = ", ".join(map(str, k))
        return tp[k] / (tp[k] + fp[k] + 1e-7)
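# A worked example of the matching rule the segmental metrics rely on (the
# numbers are illustrative): a predicted interval [10, 20) matched against a
# ground truth interval [14, 26) gives
#
#     intersection = min(20, 26) - max(10, 14) = 6
#     union        = max(20, 26) - min(10, 14) = 16
#     IoU          = 6 / 16 = 0.375
#
# With the default iou_threshold=0.5 the prediction counts as a false
# positive, and the unmatched ground truth interval as a false negative.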

class Recall(_ClassificationMetric):
    """Recall."""

    def __init__(
        self,
        num_classes: int,
        ignore_index: int = -100,
        average: str = "macro",
        ignored_classes: Set = None,
        exclusive: bool = True,
        tag_average: str = "micro",
        threshold_value: Union[float, List] = None,
    ):
        """Initialize the class.

        Parameters
        ----------
        num_classes : int
            the number of classes
        ignore_index : int, default -100
            the class index that indicates ignored samples
        average : {'macro', 'micro', 'none'}
            the method for averaging across classes
        ignored_classes : set, optional
            a set of class ids to ignore in the calculation
        exclusive : bool, default True
            set to `False` for multi-label classification tasks
        tag_average : {'micro', 'macro', 'none'}
            the method for averaging across meta tags (if given)
        threshold_value : float | list, optional
            the decision threshold value (cannot be set for exclusive classification; defaults to 0.5
            for non-exclusive); if `threshold_value` is a list, every value should correspond to the
            class under the same index

        """
        if threshold_value is None:
            threshold_value = 0.5
        super().__init__(
            num_classes,
            ignore_index,
            average,
            ignored_classes,
            exclusive,
            tag_average=tag_average,
            threshold_values=[threshold_value],
        )

    def _calculate_metric(self, tp: Dict, fp: Dict, fn: Dict, tn: Dict) -> float:
        """Calculate the metric value from true and false positive and negative counts."""
        k = self.threshold_values[0]
        if isinstance(k, list):
            k = ", ".join(map(str, k))
        return tp[k] / (tp[k] + fn[k] + 1e-7)

class SegmentalRecall(_ClassificationMetric):
    """Segmental recall (not advised for training)."""

    segmental = True

    def __init__(
        self,
        num_classes: int,
        ignore_index: int = -100,
        average: str = "macro",
        ignored_classes: Set = None,
        exclusive: bool = True,
        iou_threshold: float = 0.5,
        tag_average: str = "micro",
        threshold_value: Union[float, List] = None,
    ):
        """Initialize the class.

        Parameters
        ----------
        num_classes : int
            the number of classes
        ignore_index : int, default -100
            the class index that indicates ignored samples
        average : {'macro', 'micro', 'none'}
            the method for averaging across classes
        ignored_classes : set, optional
            a set of class ids to ignore in the calculation
        exclusive : bool, default True
            set to `False` for multi-label classification tasks
        iou_threshold : float, default 0.5
            intervals with an IoU larger than this threshold are considered correct
        tag_average : {'micro', 'macro', 'none'}
            the method for averaging across meta tags (if given)
        threshold_value : float | list, optional
            the decision threshold value (cannot be set for exclusive classification; defaults to 0.5
            for non-exclusive); if `threshold_value` is a list, every value should correspond to the
            class under the same index

        """
        if threshold_value is None:
            threshold_value = 0.5
        super().__init__(
            num_classes,
            ignore_index,
            average,
            ignored_classes,
            exclusive,
            iou_threshold,
            tag_average,
            [threshold_value],
        )

    def _calculate_metric(self, tp: Dict, fp: Dict, fn: Dict, tn: Dict) -> float:
        """Calculate the metric value from true and false positive and negative counts."""
        k = self.threshold_values[0]
        if isinstance(k, list):
            k = ", ".join(map(str, k))
        return tp[k] / (tp[k] + fn[k] + 1e-7)

class F1(_ClassificationMetric):
    """F1 score."""

    def __init__(
        self,
        num_classes: int,
        ignore_index: int = -100,
        average: str = "macro",
        ignored_classes: Set = None,
        exclusive: bool = True,
        tag_average: str = "micro",
        threshold_value: Union[float, List] = None,
        integration_interval: int = 0,
    ):
        """Initialize the class.

        Parameters
        ----------
        num_classes : int
            the number of classes
        ignore_index : int, default -100
            the class index that indicates ignored samples
        average : {'macro', 'micro', 'none'}
            the method for averaging across classes
        ignored_classes : set, optional
            a set of class ids to ignore in the calculation
        exclusive : bool, default True
            set to `False` for multi-label classification tasks
        tag_average : {'micro', 'macro', 'none'}
            the method for averaging across meta tags (if given)
        threshold_value : float | list, optional
            the decision threshold value (cannot be set for exclusive classification; defaults to 0.5
            for non-exclusive); if `threshold_value` is a list, every value should correspond to the
            class under the same index
        integration_interval : int, default 0
            the number of frames to integrate over (0 means no integration)

        """
        if threshold_value is None:
            threshold_value = 0.5
        super().__init__(
            num_classes,
            ignore_index,
            average,
            ignored_classes,
            exclusive,
            tag_average=tag_average,
            threshold_values=[threshold_value],
            integration_interval=integration_interval,
        )

    def _calculate_metric(self, tp: Dict, fp: Dict, fn: Dict, tn: Dict) -> float:
        """Calculate the metric value from true and false positive and negative counts."""
        if self.optimize:
            scores = []
            for k in self.threshold_values:
                recall = tp[k] / (tp[k] + fn[k] + 1e-7)
                precision = tp[k] / (tp[k] + fp[k] + 1e-7)
                scores.append(2 * recall * precision / (recall + precision + 1e-7))
            f1 = max(scores)
        else:
            k = self.threshold_values[0]
            if isinstance(k, list):
                k = ", ".join(map(str, k))
            recall = tp[k] / (tp[k] + fn[k] + 1e-7)
            precision = tp[k] / (tp[k] + fp[k] + 1e-7)
            f1 = 2 * recall * precision / (recall + precision + 1e-7)
        return f1
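# An end-to-end sketch for an exclusive task, assuming class scores of shape
# (batch, classes, frames) and integer labels of shape (batch, frames); the
# data is random and only meant to show the reset/update/calculate cycle:
#
#     metric = F1(num_classes=3, exclusive=True, average="macro")
#     metric.reset()
#     predicted = torch.randn(2, 3, 100)        # raw class scores
#     target = torch.randint(0, 3, (2, 100))    # frame-wise class indices
#     metric.update(predicted, target, tags=None)
#     f1 = metric.calculate()                   # a single macro-averaged value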

class SegmentalF1(_ClassificationMetric):
    """Segmental F1 score (not advised for training)."""

    segmental = True

    def __init__(
        self,
        num_classes: int,
        ignore_index: int = -100,
        average: str = "macro",
        ignored_classes: Set = None,
        exclusive: bool = True,
        iou_threshold: float = 0.5,
        tag_average: str = "micro",
        threshold_value: Union[float, List] = None,
    ):
        """Initialize the class.

        Parameters
        ----------
        num_classes : int
            the number of classes
        ignore_index : int, default -100
            the class index that indicates ignored samples
        average : {'macro', 'micro', 'none'}
            the method for averaging across classes
        ignored_classes : set, optional
            a set of class ids to ignore in the calculation
        exclusive : bool, default True
            set to `False` for multi-label classification tasks
        iou_threshold : float, default 0.5
            intervals with an IoU larger than this threshold are considered correct
        tag_average : {'micro', 'macro', 'none'}
            the method for averaging across meta tags (if given)
        threshold_value : float | list, optional
            the decision threshold value (cannot be set for exclusive classification; defaults to 0.5
            for non-exclusive); if `threshold_value` is a list, every value should correspond to the
            class under the same index

        """
        if threshold_value is None:
            threshold_value = 0.5
        super().__init__(
            num_classes,
            ignore_index,
            average,
            ignored_classes,
            exclusive,
            iou_threshold,
            tag_average,
            [threshold_value],
        )

    def _calculate_metric(self, tp: Dict, fp: Dict, fn: Dict, tn: Dict) -> float:
        """Calculate the metric value from true and false positive and negative counts."""
        if self.optimize:
            scores = []
            for k in self.threshold_values:
                recall = tp[k] / (tp[k] + fn[k] + 1e-7)
                precision = tp[k] / (tp[k] + fp[k] + 1e-7)
                scores.append(2 * recall * precision / (recall + precision + 1e-7))
            f1 = max(scores)
        else:
            k = self.threshold_values[0]
            if isinstance(k, list):
                k = ", ".join(map(str, k))
            recall = tp[k] / (tp[k] + fn[k] + 1e-7)
            precision = tp[k] / (tp[k] + fp[k] + 1e-7)
            f1 = 2 * recall * precision / (recall + precision + 1e-7)
        return f1

class Fbeta(_ClassificationMetric):
    """F-beta score."""

    def __init__(
        self,
        beta: float = 1,
        ignore_index: int = -100,
        average: str = "macro",
        num_classes: int = None,
        ignored_classes: Set = None,
        tag_average: str = "micro",
        exclusive: bool = True,
        threshold_value: float = 0.5,
    ):
        """Initialize the class.

        Parameters
        ----------
        beta : float, default 1
            the beta parameter
        ignore_index : int, default -100
            the class index that indicates ignored samples
        average : {'macro', 'micro', 'none'}
            the method for averaging across classes
        num_classes : int, optional
            the number of classes
        ignored_classes : set, optional
            a set of class ids to ignore in the calculation
        tag_average : {'micro', 'macro', 'none'}
            the method for averaging across meta tags (if given)
        exclusive : bool, default True
            set to `False` for multi-label classification tasks
        threshold_value : float, default 0.5
            the decision threshold value (cannot be set for exclusive classification)

        """
        if threshold_value is None:
            threshold_value = 0.5
        self.beta2 = beta**2
        super().__init__(
            num_classes,
            ignore_index,
            average,
            ignored_classes,
            exclusive,
            tag_average=tag_average,
            threshold_values=[threshold_value],
        )

    def _calculate_metric(self, tp: Dict, fp: Dict, fn: Dict, tn: Dict) -> float:
        """Calculate the metric value from true and false positive and negative counts."""
        if self.optimize:
            scores = []
            for k in self.threshold_values:
                recall = tp[k] / (tp[k] + fn[k] + 1e-7)
                precision = tp[k] / (tp[k] + fp[k] + 1e-7)
                scores.append(
                    (1 + self.beta2)
                    * precision
                    * recall
                    / (self.beta2 * precision + recall + 1e-7)
                )
            f1 = max(scores)
        else:
            k = self.threshold_values[0]
            if isinstance(k, list):
                k = ", ".join(map(str, k))
            recall = tp[k] / (tp[k] + fn[k] + 1e-7)
            precision = tp[k] / (tp[k] + fp[k] + 1e-7)
            f1 = (
                (1 + self.beta2)
                * precision
                * recall
                / (self.beta2 * precision + recall + 1e-7)
            )
        return f1
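# A worked instance of the F-beta formula used above, with illustrative
# numbers: for beta=2 (recall weighted higher), precision=0.5 and recall=0.8,
#
#     F_beta = (1 + 2**2) * 0.5 * 0.8 / (2**2 * 0.5 + 0.8)
#            = 2.0 / 2.8
#            ≈ 0.714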

class SegmentalFbeta(_ClassificationMetric):
    """Segmental F-beta score (not advised for training)."""

    segmental = True

    def __init__(
        self,
        beta: float = 1,
        ignore_index: int = -100,
        average: str = "macro",
        num_classes: int = None,
        ignored_classes: Set = None,
        iou_threshold: float = 0.5,
        tag_average: str = "micro",
        exclusive: bool = True,
        threshold_value: float = 0.5,
    ):
        """Initialize the class.

        Parameters
        ----------
        beta : float, default 1
            the beta parameter
        ignore_index : int, default -100
            the class index that indicates ignored samples
        average : {'macro', 'micro', 'none'}
            the method for averaging across classes
        num_classes : int, optional
            the number of classes
        ignored_classes : set, optional
            a set of class ids to ignore in the calculation
        iou_threshold : float, default 0.5
            intervals with an IoU larger than this threshold are considered correct
        tag_average : {'micro', 'macro', 'none'}
            the method for averaging across meta tags (if given)
        exclusive : bool, default True
            set to `False` for multi-label classification tasks
        threshold_value : float, default 0.5
            the decision threshold value (cannot be set for exclusive classification)

        """
        if threshold_value is None:
            threshold_value = 0.5
        self.beta2 = beta**2
        super().__init__(
            num_classes,
            ignore_index,
            average,
            ignored_classes,
            exclusive,
            iou_threshold=iou_threshold,
            tag_average=tag_average,
            threshold_values=[threshold_value],
        )

    def _calculate_metric(self, tp: Dict, fp: Dict, fn: Dict, tn: Dict) -> float:
        """Calculate the metric value from true and false positive and negative counts."""
        if self.optimize:
            scores = []
            for k in self.threshold_values:
                recall = tp[k] / (tp[k] + fn[k] + 1e-7)
                precision = tp[k] / (tp[k] + fp[k] + 1e-7)
                scores.append(
                    (1 + self.beta2)
                    * precision
                    * recall
                    / (self.beta2 * precision + recall + 1e-7)
                )
            f1 = max(scores)
        else:
            k = self.threshold_values[0]
            if isinstance(k, list):
                k = ", ".join(map(str, k))
            recall = tp[k] / (tp[k] + fn[k] + 1e-7)
            precision = tp[k] / (tp[k] + fp[k] + 1e-7)
            f1 = (
                (1 + self.beta2)
                * precision
                * recall
                / (self.beta2 * precision + recall + 1e-7)
            )
        return f1

class _SemiSegmentalMetric(_ClassificationMetric):
    """The base class for the semi-segmental metrics (see `SemiSegmentalRecall` and `SemiSegmentalPrecision`)."""

    def __init__(
        self,
        num_classes: int,
        ignore_index: int = -100,
        ignored_classes: Set = None,
        exclusive: bool = True,
        average: str = "macro",
        tag_average: str = "micro",
        delta: int = 0,
        smooth_interval: int = 0,
        iou_threshold_long: float = 0.5,
        iou_threshold_short: float = 0.5,
        short_length: int = 30,
        long_length: int = 300,
        threshold_values: List = None,
    ) -> None:
        """Initialize the class.

        Parameters
        ----------
        num_classes : int
            the number of classes in the dataset
        ignore_index : int, default -100
            the ground truth label to ignore
        ignored_classes : set, optional
            the class indices to ignore in computation
        exclusive : bool, default True
            `False` for multi-label classification tasks
        average : {"macro", "micro", "none"}
            the method to average the results over classes
        tag_average : {"macro", "micro", "none"}
            the method to average the results over meta tags (if given)
        delta : int, default 0
            the number of frames to add to each ground truth interval before computing the intersection,
            see the description of the subclasses for details
        smooth_interval : int, default 0
            intervals shorter than this number of frames will be ignored (both in prediction and in
            ground truth), see the description of the subclasses for details
        iou_threshold_long : float, default 0.5
            the intersection threshold for segments longer than `long_length` frames (between 0 and 1),
            see the description of the subclasses for details
        iou_threshold_short : float, default 0.5
            the intersection threshold for segments shorter than `short_length` frames (between 0 and 1),
            see the description of the subclasses for details
        short_length : int, default 30
            the threshold number of frames for short intervals that will have an intersection threshold of
            `iou_threshold_short`, see the description of the subclasses for details
        long_length : int, default 300
            the threshold number of frames for long intervals that will have an intersection threshold of
            `iou_threshold_long`, see the description of the subclasses for details
        threshold_values : list, optional
            the decision threshold values (cannot be set for exclusive classification; default to 0.5)

        """
        if smooth_interval > 0 and exclusive:
            warnings.warn(
                "Smoothing is not implemented for datasets with exclusive classification! "
                "Setting smooth_interval to 0..."
            )
            smooth_interval = 0
        if threshold_values == [None]:
            threshold_values = [0.5]
        super().__init__(
            num_classes,
            ignore_index,
            average,
            ignored_classes,
            exclusive,
            tag_average=tag_average,
            threshold_values=threshold_values,
        )
        self.delta = delta
        self.random_sampling = False
        self.smooth_interval = smooth_interval
        self.threshold_function = self._sigmoid_threshold_function(
            low_threshold=iou_threshold_short,
            high_threshold=iou_threshold_long,
            low_length=short_length,
            high_length=long_length,
        )
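    # How the length-adaptive threshold built above behaves, assuming
    # iou_threshold_short=0.25, iou_threshold_long=0.75, short_length=30 and
    # long_length=300 (illustrative values, not from the original file): the
    # sigmoid stored in `self.threshold_function` is centered halfway between
    # the two lengths.
    #
    #     f = metric._sigmoid_threshold_function(0.25, 0.75, 30, 300)
    #     f(torch.tensor(30.0))   # ~0.26 (close to the short-interval threshold)
    #     f(torch.tensor(165.0))  # 0.50 (the midpoint)
    #     f(torch.tensor(300.0))  # ~0.74 (close to the long-interval threshold)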
    def update(
        self,
        predicted: torch.Tensor,
        target: torch.Tensor,
        tags: torch.Tensor,
    ) -> None:
        """Update the internal state (with a batch).

        Parameters
        ----------
        predicted : torch.Tensor
            the main prediction tensor generated by the model
        target : torch.Tensor
            the corresponding main target tensor
        tags : torch.Tensor
            the tensor of meta tags (or `None`, if tags are not given)

        """
        if self.exclusive:
            predicted = torch.max(predicted, 1)[1]
        predicted = torch.cat(
            [
                copy(predicted),
                -100
                * torch.ones((*predicted.shape[:-1], self.delta + 1)).to(
                    predicted.device
                ),
            ],
            dim=-1,
        )
        target = torch.cat(
            [
                copy(target),
                -100
                * torch.ones((*target.shape[:-1], self.delta + 1)).to(target.device),
            ],
            dim=-1,
        )
        if self.exclusive:
            predicted = predicted.flatten()
            target = target.flatten()
        else:
            num_classes = target.shape[1]
            predicted = predicted.transpose(1, 2).reshape(-1, num_classes)
            target = target.transpose(1, 2).reshape(-1, num_classes)
        if self.tag_average == "micro" or tags is None or tags[0] is None:
            tag_set = {None}
        else:
            tag_set = set(tags)
        self.tags.update(tag_set)
        if self.smooth_interval > 0:
            target = self._smooth(target, self.smooth_interval)
        for thr in self.threshold_values:
            key = thr
            for t in tag_set:
                if t is None:
                    predicted_t = deepcopy(predicted)
                    target_t = target
                else:
                    predicted_t = deepcopy(predicted[tags == t])
                    target_t = target[tags == t]
                if not self.exclusive:
                    if isinstance(thr, list):
                        # apply the per-class decision thresholds column-wise
                        # (the tensor is (frames, classes) after the reshape above)
                        for i, tt in enumerate(thr):
                            predicted_t[:, i] = (predicted_t[:, i] > tt).int()
                        key = ", ".join(map(str, thr))
                    else:
                        predicted_t = (predicted_t > thr).int()
                    if self.smooth_interval > 0:
                        predicted_t = self._smooth(predicted_t, self.smooth_interval)
                for c in self.classes:
                    # recall side: match ground truth intervals against the prediction
                    if not self.random_sampling:
                        if self.exclusive:
                            target_intervals = self._get_intervals(target_t == c)
                        else:
                            target_intervals = self._get_intervals(target_t[:, c] == 1)
                        target_lengths = [
                            end - start for start, end in target_intervals
                        ]
                        target_intervals = [
                            [
                                max(start - self.delta, 0),
                                min(end + self.delta, len(target_t)),
                            ]
                            for start, end in target_intervals
                        ]
                    else:
                        if self.exclusive:
                            target_arr = target_t == c
                        else:
                            target_arr = target_t[:, c] == 1
                        target_points = torch.where(target_arr)[0][::20]
                        target_intervals = []
                        for p in target_points:
                            target_intervals.append(
                                [
                                    max(0, p - self.delta),
                                    min(len(target_arr), p + self.delta),
                                ]
                            )
                        target_lengths = [
                            end - start for start, end in target_intervals
                        ]
                    if self.exclusive:
                        predicted_arr = predicted_t == c
                    else:
                        predicted_arr = predicted_t[:, c] == 1
                    for interval, l in zip(target_intervals, target_lengths):
                        intersection = torch.sum(
                            predicted_arr[interval[0] : interval[1]]
                        )
                        IoU = intersection / l
                        if IoU >= self.threshold_function(l):
                            self.tp[self._key(t, c)][key] += 1
                        else:
                            self.fn[self._key(t, c)][key] += 1

                    # precision side: match predicted intervals against the ground
                    # truth (the matches are accumulated in `tn` and the misses in `fp`)
                    if not self.random_sampling:
                        if self.exclusive:
                            predicted_intervals = self._get_intervals(predicted_t == c)
                        else:
                            predicted_intervals = self._get_intervals(
                                predicted_t[:, c] == 1
                            )
                        predicted_intervals_delta = [
                            [
                                max(start - self.delta, 0),
                                min(end + self.delta, len(target_t)),
                            ]
                            for start, end in predicted_intervals
                        ]
                    else:
                        if self.exclusive:
                            predicted_arr = predicted_t == c
                        else:
                            predicted_arr = predicted_t[:, c] == 1
                        predicted_points = torch.where(predicted_arr)[0][::20]
                        predicted_intervals = []
                        for p in predicted_points:
                            predicted_intervals.append(
                                [
                                    max(0, p - self.delta),
                                    min(len(predicted_arr), p + self.delta),
                                ]
                            )
                        predicted_intervals_delta = predicted_intervals
                    if self.exclusive:
                        target_arr = target_t == c
                        target_arr[target_t == -100] = -100
                    else:
                        target_arr = target_t[:, c]
                    for interval, interval_delta in zip(
                        predicted_intervals, predicted_intervals_delta
                    ):
                        if torch.sum(
                            target_arr[interval_delta[0] : interval_delta[1]] != -100
                        ) < 0.3 * (interval[1] - interval[0]):
                            continue
                        l = torch.sum(target_arr[interval[0] : interval[1]] != -100)
                        intersection = torch.sum(
                            target_arr[interval_delta[0] : interval_delta[1]] == 1
                        )
                        IoU = intersection / l
                        if IoU >= self.threshold_function(l):
                            self.tn[self._key(t, c)][key] += 1
                        else:
                            self.fp[self._key(t, c)][key] += 1

class SemiSegmentalRecall(_SemiSegmentalMetric):
    """Semi-segmental recall (not advised for training).

    A metric in-between segmental and frame-wise recall.

    The metric follows this algorithm:

    1) smooth over too-short intervals, both in ground truth and in prediction (first remove
       groups of zeros shorter than `smooth_interval` and then do the same with groups of ones),
    2) add `delta` frames to each ground truth interval at both ends and count the number of predicted
       positive frames at the resulting intervals (intersection),
    3) calculate the threshold for each interval as
       `t = iou_threshold_short + sigmoid(4 * (a * x + b)) * (iou_threshold_long - iou_threshold_short)`,
       where `a = 2 / (long_length - short_length)`, `b = 1 - a * long_length` and `x` is the length of
       the interval before `delta` was added,
    4) for each interval, if the intersection is at least `t * x`, the interval is labeled as true
       positive (`TP`), and otherwise as false negative (`FN`),
    5) the final metric value is computed as `TP / (TP + FN)`.
    """

    def __init__(
        self,
        num_classes: int,
        ignore_index: int = -100,
        ignored_classes: Set = None,
        exclusive: bool = True,
        average: str = "macro",
        tag_average: str = "micro",
        delta: int = 0,
        smooth_interval: int = 0,
        iou_threshold_long: float = 0.5,
        iou_threshold_short: float = 0.5,
        short_length: int = 30,
        long_length: int = 300,
        threshold_value: Union[float, List] = None,
    ) -> None:
        """Initialize the class.

        Parameters
        ----------
        num_classes : int
            the number of classes in the dataset
        ignore_index : int, default -100
            the ground truth label to ignore
        ignored_classes : set, optional
            the class indices to ignore in computation
        exclusive : bool, default True
            `False` for multi-label classification tasks
        average : {"macro", "micro", "none"}
            the method to average the results over classes
        tag_average : {"macro", "micro", "none"}
            the method to average the results over meta tags (if given)
        delta : int, default 0
            the number of frames to add to each ground truth interval before computing the intersection,
            see the description of the class for details
        smooth_interval : int, default 0
            intervals shorter than this number of frames will be ignored (both in prediction and in
            ground truth), see the description of the class for details
        iou_threshold_long : float, default 0.5
            the intersection threshold for segments longer than `long_length` frames (between 0 and 1),
            see the description of the class for details
        iou_threshold_short : float, default 0.5
            the intersection threshold for segments shorter than `short_length` frames (between 0 and 1),
            see the description of the class for details
        short_length : int, default 30
            the threshold number of frames for short intervals that will have an intersection threshold of
            `iou_threshold_short`, see the description of the class for details
        long_length : int, default 300
            the threshold number of frames for long intervals that will have an intersection threshold of
            `iou_threshold_long`, see the description of the class for details
        threshold_value : float | list, optional
            the decision threshold value (cannot be set for exclusive classification; defaults to 0.5)

        """
        super().__init__(
            num_classes,
            ignore_index,
            ignored_classes,
            exclusive,
            average,
            tag_average,
            delta,
            smooth_interval,
            iou_threshold_long,
            iou_threshold_short,
            short_length,
            long_length,
            [threshold_value],
        )

    def _calculate_metric(self, tp: Dict, fp: Dict, fn: Dict, tn: Dict) -> float:
        """Calculate the metric value from true and false positive and negative counts."""
        k = self.threshold_values[0]
        if isinstance(k, list):
            k = ", ".join(map(str, k))
        return tp[k] / (tp[k] + fn[k] + 1e-7)
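# A hedged usage sketch for the semi-segmental metrics (random placeholder
# data; the parameter values are illustrative, not recommendations):
#
#     metric = SemiSegmentalRecall(
#         num_classes=3,
#         exclusive=False,
#         delta=5,
#         smooth_interval=3,
#         iou_threshold_short=0.25,
#         iou_threshold_long=0.75,
#     )
#     metric.reset()
#     predicted = torch.rand(2, 3, 100)
#     target = (torch.rand(2, 3, 100) > 0.5).int()
#     metric.update(predicted, target, tags=None)
#     recall = metric.calculate()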

class SemiSegmentalPrecision(_SemiSegmentalMetric):
    """Semi-segmental precision (not advised for training).

    A metric in-between segmental and frame-wise precision.

    The metric follows this algorithm:

    1) smooth over too-short intervals, both in ground truth and in prediction (first remove
       groups of zeros shorter than `smooth_interval` and then do the same with groups of ones),
    2) add `delta` frames to each predicted interval at both ends and count the number of ground truth
       positive frames at the resulting intervals (intersection),
    3) calculate the threshold for each interval as
       `t = iou_threshold_short + sigmoid(4 * (a * x + b)) * (iou_threshold_long - iou_threshold_short)`,
       where `a = 2 / (long_length - short_length)`, `b = 1 - a * long_length` and `x` is the length of
       the interval before `delta` was added,
    4) for each interval, if the intersection is at least `t * x`, the interval is labeled as true
       positive (`TP`), and otherwise as false positive (`FP`),
    5) the final metric value is computed as `TP / (TP + FP)`.
    """

    def __init__(
        self,
        num_classes: int,
        ignore_index: int = -100,
        ignored_classes: Set = None,
        exclusive: bool = True,
        average: str = "macro",
        tag_average: str = "micro",
        delta: int = 0,
        smooth_interval: int = 0,
        iou_threshold_long: float = 0.5,
        iou_threshold_short: float = 0.5,
        short_length: int = 30,
        long_length: int = 300,
        threshold_value: Union[float, List] = None,
    ) -> None:
        """Initialize the class.

        Parameters
        ----------
        num_classes : int
            the number of classes in the dataset
        ignore_index : int, default -100
            the ground truth label to ignore
        ignored_classes : set, optional
            the class indices to ignore in computation
        exclusive : bool, default True
            `False` for multi-label classification tasks
        average : {"macro", "micro", "none"}
            the method to average the results over classes
        tag_average : {"macro", "micro", "none"}
            the method to average the results over meta tags (if given)
        delta : int, default 0
            the number of frames to add to each predicted interval before computing the intersection,
            see the description of the class for details
        smooth_interval : int, default 0
            intervals shorter than this number of frames will be ignored (both in prediction and in
            ground truth), see the description of the class for details
        iou_threshold_long : float, default 0.5
            the intersection threshold for segments longer than `long_length` frames (between 0 and 1),
            see the description of the class for details
        iou_threshold_short : float, default 0.5
            the intersection threshold for segments shorter than `short_length` frames (between 0 and 1),
            see the description of the class for details
        short_length : int, default 30
            the threshold number of frames for short intervals that will have an intersection threshold of
            `iou_threshold_short`, see the description of the class for details
        long_length : int, default 300
            the threshold number of frames for long intervals that will have an intersection threshold of
            `iou_threshold_long`, see the description of the class for details
        threshold_value : float | list, optional
            the decision threshold value (cannot be set for exclusive classification; defaults to 0.5)

        """
        super().__init__(
            num_classes,
            ignore_index,
            ignored_classes,
            exclusive,
            average,
            tag_average,
            delta,
            smooth_interval,
            iou_threshold_long,
            iou_threshold_short,
            short_length,
            long_length,
            [threshold_value],
        )

    def _calculate_metric(self, tp: Dict, fp: Dict, fn: Dict, tn: Dict) -> float:
        """Calculate the metric value from true and false positive and negative counts."""
        # the precision-side matches are accumulated in `tn` and the misses in `fp`
        k = self.threshold_values[0]
        if isinstance(k, list):
            k = ", ".join(map(str, k))
        return tn[k] / (tn[k] + fp[k] + 1e-7)

class SemiSegmentalF1(_SemiSegmentalMetric):
    """The F1 score for semi-segmental recall and precision (not advised for training)."""

    def __init__(
        self,
        num_classes: int,
        ignore_index: int = -100,
        ignored_classes: Set = None,
        exclusive: bool = True,
        average: str = "macro",
        tag_average: str = "micro",
        delta: int = 0,
        smooth_interval: int = 0,
        iou_threshold_long: float = 0.5,
        iou_threshold_short: float = 0.5,
        short_length: int = 30,
        long_length: int = 300,
        threshold_value: Union[float, List] = None,
    ) -> None:
        """Initialize the class.

        Parameters
        ----------
        num_classes : int
            the number of classes in the dataset
        ignore_index : int, default -100
            the ground truth label to ignore
        ignored_classes : set, optional
            the class indices to ignore in computation
        exclusive : bool, default True
            `False` for multi-label classification tasks
        average : {"macro", "micro", "none"}
            the method to average the results over classes
        tag_average : {"macro", "micro", "none"}
            the method to average the results over meta tags (if given)
        delta : int, default 0
            the number of frames to add to each interval before computing the intersection,
            see the description of the recall and precision classes for details
        smooth_interval : int, default 0
            intervals shorter than this number of frames will be ignored (both in prediction and in
            ground truth), see the description of the recall and precision classes for details
        iou_threshold_long : float, default 0.5
            the intersection threshold for segments longer than `long_length` frames (between 0 and 1),
            see the description of the recall and precision classes for details
        iou_threshold_short : float, default 0.5
            the intersection threshold for segments shorter than `short_length` frames (between 0 and 1),
            see the description of the recall and precision classes for details
        short_length : int, default 30
            the threshold number of frames for short intervals that will have an intersection threshold of
            `iou_threshold_short`, see the description of the recall and precision classes for details
        long_length : int, default 300
            the threshold number of frames for long intervals that will have an intersection threshold of
            `iou_threshold_long`, see the description of the recall and precision classes for details
        threshold_value : float | list, optional
            the decision threshold value (cannot be set for exclusive classification; defaults to 0.5)

        """
        super().__init__(
            num_classes,
            ignore_index,
            ignored_classes,
            exclusive,
            average,
            tag_average,
            delta,
            smooth_interval,
            iou_threshold_long,
            iou_threshold_short,
            short_length,
            long_length,
            [threshold_value],
        )

    def _calculate_metric(self, tp: Dict, fp: Dict, fn: Dict, tn: Dict) -> float:
        """Calculate the metric value from true and false positive and negative counts."""
        # the precision-side matches are accumulated in `tn` and the misses in `fp`
        if self.optimize:
            scores = []
            for k in self.threshold_values:
                recall = tp[k] / (tp[k] + fn[k] + 1e-7)
                precision = tn[k] / (tn[k] + fp[k] + 1e-7)
                scores.append(2 * recall * precision / (recall + precision + 1e-7))
            f1 = max(scores)
        else:
            k = self.threshold_values[0]
            if isinstance(k, list):
                k = ", ".join(map(str, k))
            recall = tp[k] / (tp[k] + fn[k] + 1e-7)
            precision = tn[k] / (tn[k] + fp[k] + 1e-7)
            f1 = 2 * recall * precision / (recall + precision + 1e-7)
        return f1

class SemiSegmentalPR_AUC(_SemiSegmentalMetric):
    """The area under the precision-recall curve for semi-segmental metrics (not advised for training)."""

    def __init__(
        self,
        num_classes: int,
        ignore_index: int = -100,
        ignored_classes: Set = None,
        exclusive: bool = True,
        average: str = "macro",
        tag_average: str = "micro",
        delta: int = 0,
        smooth_interval: int = 0,
        iou_threshold_long: float = 0.5,
        iou_threshold_short: float = 0.5,
        short_length: int = 30,
        long_length: int = 300,
        threshold_step: float = 0.1,
    ) -> None:
        """Initialize the class (see `_SemiSegmentalMetric` for the parameter descriptions)."""
        super().__init__(
            num_classes,
            ignore_index,
            ignored_classes,
            exclusive,
            average,
            tag_average,
            delta,
            smooth_interval,
            iou_threshold_long,
            iou_threshold_short,
            short_length,
            long_length,
            list(np.arange(0, 1, threshold_step)),
        )

    def _calculate_metric(self, tp: Dict, fp: Dict, fn: Dict, tn: Dict) -> float:
        """Calculate the metric value from true and false positive and negative counts."""
        precisions = []
        recalls = []
        for k in sorted(self.threshold_values):
            precisions.append(tn[k] / (tn[k] + fp[k] + 1e-7))
            recalls.append(tp[k] / (tp[k] + fn[k] + 1e-7))
        return metrics.auc(x=recalls, y=precisions)


class Accuracy(Metric):
    """Accuracy."""

    def __init__(self, ignore_index: int = -100):
        """Initialize the class.

        Parameters
        ----------
        ignore_index : int, default -100
            the class index that indicates ignored samples

        """
        super().__init__()
        self.ignore_index = ignore_index

    def reset(self) -> None:
        """Reset the internal state (at the beginning of an epoch)."""
        self.total = 0
        self.correct = 0

    def calculate(self) -> float:
        """Calculate the metric value.

        Returns
        -------
        metric : float
            the metric value

        """
        return self.correct / (self.total + 1e-7)

    def update(
        self,
        predicted: torch.Tensor,
        target: torch.Tensor,
        tags: torch.Tensor = None,
    ) -> None:
        """Update the internal state (with a batch).

        Parameters
        ----------
        predicted : torch.Tensor
            the main prediction tensor generated by the model
        target : torch.Tensor
            the corresponding main target tensor
        tags : torch.Tensor
            the tensor of meta tags (or `None`, if tags are not given)

        """
        mask = target != self.ignore_index
        self.total += torch.sum(mask)
        self.correct += torch.sum((target == predicted)[mask])
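# A minimal sketch for `Accuracy`; since `update` compares `predicted` to
# `target` element-wise, both tensors should hold class indices of the same
# shape for an exclusive task (the data below is a random placeholder):
#
#     metric = Accuracy()
#     metric.reset()
#     predicted = torch.randint(0, 3, (2, 100))  # predicted class indices
#     target = torch.randint(0, 3, (2, 100))     # ground truth class indices
#     metric.update(predicted, target)
#     accuracy = metric.calculate()              # fraction of matching frames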

class Count(Metric):
    """Fraction of samples labeled by the model as each class."""

    def __init__(self, classes: Set, exclusive: bool = True):
        """Initialize the class.

        Parameters
        ----------
        classes : set
            the set of classes to count
        exclusive : bool, default True
            set to `False` for multi-label classification tasks

        """
        super().__init__()
        self.classes = classes
        self.exclusive = exclusive

    def reset(self) -> None:
        """Reset the internal state (at the beginning of an epoch)."""
        self.count = defaultdict(lambda: 0)
        self.total = 0

    def update(
        self,
        predicted: torch.Tensor,
        target: torch.Tensor,
        tags: torch.Tensor,
    ) -> None:
        """Update the internal state (with a batch).

        Parameters
        ----------
        predicted : torch.Tensor
            the main prediction tensor generated by the model
        target : torch.Tensor
            the corresponding main target tensor (unused by this metric)
        tags : torch.Tensor
            the tensor of meta tags (or `None`, if tags are not given)

        """
        if self.exclusive:
            for c in self.classes:
                self.count[c] += torch.sum(predicted == c)
            self.total += torch.numel(predicted)
        else:
            for c in self.classes:
                self.count[c] += torch.sum(predicted[:, c, :] == 1)
            self.total += torch.numel(predicted[:, 0, :])

    def calculate(self) -> Dict:
        """Calculate the metric (at the end of an epoch).

        Returns
        -------
        result : dict
            a dictionary where the keys are class indices and the values are class metric values

        """
        for c in self.classes:
            self.count[c] = self.count[c] / (self.total + 1e-7)
        return dict(self.count)
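# A short sketch for `Count` (the target argument is unused by this metric, so
# a `None` placeholder is passed here; the data is random):
#
#     metric = Count(classes={0, 1, 2}, exclusive=True)
#     metric.reset()
#     metric.update(torch.randint(0, 3, (2, 100)), target=None, tags=None)
#     fractions = metric.calculate()  # {class index: fraction of frames}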
1806 1807 Parameters 1808 ---------- 1809 predicted : torch.Tensor 1810 the main prediction tensor generated by the model 1811 ssl_predicted : torch.Tensor 1812 the SSL prediction tensor generated by the model 1813 target : torch.Tensor 1814 the corresponding main target tensor 1815 ssl_target : torch.Tensor 1816 the corresponding SSL target tensor 1817 tags : torch.Tensor 1818 the tensor of meta tags (or `None`, if tags are not given) 1819 1820 """ 1821 mask = target != self.ignore_index 1822 self.total += torch.sum(mask) 1823 if self.exclusive: 1824 predicted = predicted[mask].flatten() 1825 target = target[mask].flatten() 1826 self.edit_distance += editdistance.eval( 1827 predicted.detach().cpu().numpy(), target.detach().cpu().numpy() 1828 ) 1829 else: 1830 for c in range(target.shape[1]): 1831 predicted_class = predicted[:, c, :][mask[:, c, :]].flatten() 1832 target_class = target[:, c, :][mask[:, c, :]].flatten() 1833 self.edit_distance += editdistance.eval( 1834 predicted_class.detach().cpu().tolist(), 1835 target_class.detach().cpu().tolist(), 1836 ) 1837 1838 def _is_equal(self, a, b): 1839 """Compare while ignoring samples marked with ignore_index.""" 1840 if self.ignore_index in [a, b] or a == b: 1841 return True 1842 else: 1843 return False 1844 1845 def calculate(self) -> float: 1846 """Calculate the metric (at the end of an epoch). 1847 1848 Returns 1849 ------- 1850 result : float 1851 the metric value 1852 1853 """ 1854 return self.edit_distance / (self.total + 1e-7) 1855 1856 1857class mAP(Metric): 1858 """Mean average precision (segmental) (not advised for training).""" 1859 1860 needs_raw_data = True 1861 1862 def __init__( 1863 self, 1864 exclusive, 1865 num_classes, 1866 average="macro", 1867 iou_threshold=0.5, 1868 threshold_value=0.5, 1869 ignored_classes=None, 1870 ): 1871 """Initialize the class. 
1872 1873 Parameters 1874 ---------- 1875 exclusive : bool 1876 set to False for multi-label classification tasks 1877 num_classes : int 1878 the number of classes 1879 average : {"macro", "micro", "none"} 1880 the type of averaging to perform 1881 iou_threshold : float, default 0.5 1882 the IoU threshold for matching 1883 threshold_value : float, default 0.5 1884 the threshold value for binarization 1885 ignored_classes : list, default None 1886 the list of classes to ignore 1887 1888 """ 1889 if ignored_classes is None: 1890 ignored_classes = [] 1891 self.average = average 1892 self.iou_threshold = iou_threshold 1893 self.threshold = threshold_value 1894 self.exclusive = exclusive 1895 self.classes = [x for x in list(range(num_classes)) if x not in ignored_classes] 1896 super().__init__() 1897 1898 def match(self, lst, ratio, ground): 1899 """Given a list of proposals, match them to the ground truth boxes.""" 1900 lst = sorted(lst, key=lambda x: x[2]) 1901 1902 def overlap(prop, ground): 1903 s_p, e_p, _ = prop 1904 s_g, e_g, _ = ground 1905 return (min(e_p, e_g) - max(s_p, s_g)) / (max(e_p, e_g) - min(s_p, s_g)) 1906 1907 cos_map = [-1 for x in range(len(lst))] 1908 count_map = [0 for x in range(len(ground))] 1909 1910 for x in range(len(lst)): 1911 for y in range(len(ground)): 1912 if overlap(lst[x], ground[y]) < ratio: 1913 continue 1914 if overlap(lst[x], ground[y]) < overlap(lst[x], ground[cos_map[x]]): 1915 continue 1916 cos_map[x] = y 1917 if cos_map[x] != -1: 1918 count_map[cos_map[x]] += 1 1919 positive = sum([(x > 0) for x in count_map]) 1920 return cos_map, count_map, positive, [x[2] for x in lst] 1921 1922 def reset(self) -> None: 1923 """Reset the internal state (at the beginning of an epoch).""" 1924 self.count_map = defaultdict(lambda: []) 1925 self.positive = defaultdict(lambda: 0) 1926 self.cos_map = defaultdict(lambda: []) 1927 self.confidence = defaultdict(lambda: []) 1928 1929 def calc_pr(self, positive, proposal, ground): 1930 """Get precision.""" 1931 if proposal == 0: 1932 return 0, 0 1933 if ground == 0: 1934 return 0, 0 1935 return (1.0 * positive) / proposal, (1.0 * positive) / ground 1936 1937 def calculate(self) -> Union[float, Dict]: 1938 """Calculate the metric (at the end of an epoch).""" 1939 if self.average == "micro": 1940 confidence = [] 1941 count_map = [] 1942 cos_map = [] 1943 positive = sum(self.positive.values()) 1944 for key in self.count_map.keys(): 1945 confidence += self.confidence[key] 1946 cos_map += list(np.array(self.cos_map[key]) + len(count_map)) 1947 count_map += self.count_map[key] 1948 return self.ap(cos_map, count_map, positive, confidence) 1949 results = { 1950 key: self.ap( 1951 self.cos_map[key], 1952 self.count_map[key], 1953 self.positive[key], 1954 self.confidence[key], 1955 ) 1956 for key in self.count_map.keys() 1957 } 1958 if self.average == "none": 1959 return results 1960 else: 1961 return float(np.mean(list(results.values()))) 1962 1963 def ap(self, cos_map, count_map, positive, confidence): 1964 """Compute average precision.""" 1965 indices = np.argsort(confidence) 1966 cos_map = list(np.array(cos_map)[indices]) 1967 score = 0 1968 number_proposal = len(cos_map) 1969 number_ground = len(count_map) 1970 old_precision, old_recall = self.calc_pr( 1971 positive, number_proposal, number_ground 1972 ) 1973 1974 for x in range(len(cos_map)): 1975 number_proposal -= 1 1976 if cos_map[x] == -1: 1977 continue 1978 count_map[cos_map[x]] -= 1 1979 if count_map[cos_map[x]] == 0: 1980 positive -= 1 1981 1982 precision, recall = 
self.calc_pr(positive, number_proposal, number_ground) 1983 if precision > old_precision: 1984 old_precision = precision 1985 score += old_precision * (old_recall - recall) 1986 old_recall = recall 1987 return score 1988 1989 def _get_intervals( 1990 self, tensor: torch.Tensor, probability: torch.Tensor = None 1991 ) -> Union[Tuple, torch.Tensor]: 1992 """Get `True` group beginning and end indices from a boolean tensor and average probability over these intervals.""" 1993 output, indices = torch.unique_consecutive(tensor, return_inverse=True) 1994 true_indices = torch.where(output)[0] 1995 starts = torch.tensor( 1996 [(indices == i).nonzero(as_tuple=True)[0][0] for i in true_indices] 1997 ) 1998 ends = torch.tensor( 1999 [(indices == i).nonzero(as_tuple=True)[0][-1] + 1 for i in true_indices] 2000 ) 2001 confidence = torch.tensor( 2002 [probability[indices == i].mean() for i in true_indices] 2003 ) 2004 return torch.stack([starts, ends, confidence]).T 2005 2006 def update( 2007 self, 2008 predicted: torch.Tensor, 2009 target: torch.Tensor, 2010 tags: torch.Tensor, 2011 ) -> None: 2012 """Update the state (at the end of each batch).""" 2013 predicted = torch.cat( 2014 [ 2015 copy(predicted), 2016 -100 * torch.ones((*predicted.shape[:-1], 1)).to(predicted.device), 2017 ], 2018 dim=-1, 2019 ) 2020 target = torch.cat( 2021 [ 2022 copy(target), 2023 -100 * torch.ones((*target.shape[:-1], 1)).to(target.device), 2024 ], 2025 dim=-1, 2026 ) 2027 num_classes = predicted.shape[1] 2028 predicted = predicted.transpose(1, 2).reshape(-1, num_classes) 2029 if self.exclusive: 2030 target = target.flatten() 2031 else: 2032 target = target.transpose(1, 2).reshape(-1, num_classes) 2033 probability = copy(predicted) 2034 if not self.exclusive: 2035 predicted = (predicted > self.threshold).int() 2036 else: 2037 predicted = torch.max(predicted, 1)[1] 2038 for c in self.classes: 2039 if self.exclusive: 2040 predicted_intervals = self._get_intervals( 2041 predicted == c, probability=probability[:, c] 2042 ) 2043 target_intervals = self._get_intervals( 2044 target == c, probability=probability[:, c] 2045 ) 2046 else: 2047 predicted_intervals = self._get_intervals( 2048 predicted[:, c] == 1, probability=probability[:, c] 2049 ) 2050 target_intervals = self._get_intervals( 2051 target[:, c] == 1, probability=probability[:, c] 2052 ) 2053 cos_map, count_map, positive, confidence = self.match( 2054 predicted_intervals, self.iou_threshold, target_intervals 2055 ) 2056 cos_map = np.array(cos_map) 2057 cos_map[cos_map != -1] += len(self.count_map[c]) 2058 self.cos_map[c] += list(cos_map) 2059 self.count_map[c] += count_map 2060 self.confidence[c] += confidence 2061 self.positive[c] += positive
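To make the matching logic concrete, here is a toy, self-contained sketch of the IoU test that mAP.match applies: proposals and ground truth segments are (start, end, confidence) triples, and a proposal is only matched to a ground truth segment when their IoU clears the ratio (0.5 in this sketch). All names and values here are illustrative, not part of the module.

def iou(a, b):
    # intersection over union of two (start, end, _) intervals;
    # like `overlap` above, this goes negative for disjoint intervals
    inter = min(a[1], b[1]) - max(a[0], b[0])
    union = max(a[1], b[1]) - min(a[0], b[0])
    return inter / union

proposals = [(0, 10, 0.9), (12, 20, 0.6), (40, 50, 0.8)]
ground = [(0, 12, 1.0), (38, 52, 1.0)]

matches = []
for p in proposals:
    best = max(range(len(ground)), key=lambda g: iou(p, ground[g]))
    matches.append(best if iou(p, ground[best]) >= 0.5 else -1)

print(matches)  # [0, -1, 1]: the middle proposal does not overlap anything enough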
class PR_AUC(_ClassificationMetric):
    """Area under precision-recall curve (not advised for training)."""

    def __init__(
        self,
        num_classes: int,
        ignore_index: int = -100,
        average: str = "macro",
        ignored_classes: Set = None,
        exclusive: bool = False,
        tag_average: str = "micro",
        threshold_step: float = 0.1,
    ):
        """Initialize the metric.

        Parameters
        ----------
        ignore_index : int, default -100
            the class index that indicates ignored samples
        average: {'macro', 'micro', 'none'}
            method for averaging across classes
        num_classes : int
            number of classes
        ignored_classes : set, optional
            a set of class ids to ignore in calculation
        exclusive: bool, default False
            set to False for multi-label classification tasks
        tag_average: {'micro', 'macro', 'none'}
            method for averaging across meta tags (if given)
        threshold_step : float, default 0.1
            the decision threshold step

        """
        if exclusive:
            raise ValueError(
                "The PR-AUC metric is not implemented for exclusive classification!"
            )
        super().__init__(
            num_classes,
            ignore_index,
            average,
            ignored_classes,
            exclusive,
            tag_average=tag_average,
            threshold_values=list(np.arange(0, 1, threshold_step)),
        )

    def _calculate_metric(self, tp: Dict, fp: Dict, fn: Dict, tn: Dict) -> float:
        """Calculate the metric value from true and false positive and negative rates."""
        precisions = []
        recalls = []
        for k in sorted(self.threshold_values):
            precisions.append(tp[k] / (tp[k] + fp[k] + 1e-7))
            recalls.append(tp[k] / (tp[k] + fn[k] + 1e-7))
        return metrics.auc(x=recalls, y=precisions)
Area under precision-recall curve (not advised for training).
Initialize the metric.
Parameters
ignore_index : int, default -100
the class index that indicates ignored samples
average: {'macro', 'micro', 'none'}
method for averaging across classes
num_classes : int
number of classes
ignored_classes : set, optional
a set of class ids to ignore in calculation
exclusive: bool, default False
set to False for multi-label classification tasks
tag_average: {'micro', 'macro', 'none'}
method for averaging across meta tags (if given)
threshold_step : float, default 0.1
the decision threshold step
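A minimal usage sketch, assuming the reset/update/calculate workflow that Accuracy makes explicit above, with raw multi-label probabilities of shape batch x classes x frames (shapes and values are illustrative):

import torch

metric = PR_AUC(num_classes=3, exclusive=False, threshold_step=0.1)
metric.reset()
predicted = torch.rand(8, 3, 100)             # raw per-class probabilities
target = (torch.rand(8, 3, 100) > 0.7).int()  # binary multi-label targets
metric.update(predicted, target, tags=None)
print(metric.calculate())                     # macro-averaged PR-AUC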
class Precision(_ClassificationMetric):
    """Precision."""

    def __init__(
        self,
        num_classes: int,
        ignore_index: int = -100,
        average: str = "macro",
        ignored_classes: Set = None,
        exclusive: bool = True,
        tag_average: str = "micro",
        threshold_value: Union[float, List] = None,
    ):
        """Initialize the class.

        Parameters
        ----------
        ignore_index : int, default -100
            the class index that indicates ignored samples
        average: {'macro', 'micro', 'none'}
            method for averaging across classes
        num_classes : int
            number of classes
        ignored_classes : set, optional
            a set of class ids to ignore in calculation
        exclusive: bool, default True
            set to False for multi-label classification tasks
        tag_average: {'micro', 'macro', 'none'}
            method for averaging across meta tags (if given)
        threshold_value : float | list, optional
            the decision threshold value (cannot be set for exclusive classification; defaults to 0.5
            for non-exclusive); if `threshold_value` is a list, every value should correspond to the class
            under the same index

        """
        if threshold_value is None:
            threshold_value = 0.5
        super().__init__(
            num_classes,
            ignore_index,
            average,
            ignored_classes,
            exclusive,
            tag_average=tag_average,
            threshold_values=[threshold_value],
        )

    def _calculate_metric(self, tp: Dict, fp: Dict, fn: Dict, tn: Dict) -> float:
        """Calculate the metric value from true and false positive and negative rates."""
        k = self.threshold_values[0]
        if isinstance(k, list):
            k = ", ".join(map(str, k))
        return tp[k] / (tp[k] + fp[k] + 1e-7)
Precision.
Initialize the class.
Parameters
ignore_index : int, default -100
the class index that indicates ignored samples
average: {'macro', 'micro', 'none'}
method for averaging across classes
num_classes : int
number of classes
ignored_classes : set, optional
a set of class ids to ignore in calculation
exclusive: bool, default True
set to False for multi-label classification tasks
tag_average: {'micro', 'macro', 'none'}
method for averaging across meta tags (if given)
threshold_value : float | list, optional
the decision threshold value (cannot be set for exclusive classification; defaults to 0.5
for non-exclusive); if threshold_value is a list, every value should correspond to the class
under the same index
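If threshold_value is given as a list, each entry acts as the decision threshold for the class with the same index. A hypothetical configuration sketch (values are illustrative):

# class 2 only counts as predicted when its probability exceeds 0.7
metric = Precision(
    num_classes=3,
    exclusive=False,
    average="none",                   # report one value per class
    threshold_value=[0.3, 0.5, 0.7],  # one threshold per class index
)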
class SegmentalPrecision(_ClassificationMetric):
    """Segmental precision (not advised for training)."""

    segmental = True

    def __init__(
        self,
        num_classes: int,
        ignore_index: int = -100,
        average: str = "macro",
        ignored_classes: Set = None,
        exclusive: bool = True,
        iou_threshold: float = 0.5,
        tag_average: str = "micro",
        threshold_value: Union[float, List] = None,
    ):
        """Initialize the class.

        Parameters
        ----------
        ignore_index : int, default -100
            the class index that indicates ignored samples
        average: {'macro', 'micro', 'none'}
            method for averaging across classes
        num_classes : int
            number of classes
        ignored_classes : set, optional
            a set of class ids to ignore in calculation
        exclusive: bool, default True
            set to False for multi-label classification tasks
        iou_threshold : float, default 0.5
            intervals with IoU larger than this threshold are considered correct
        tag_average: {'micro', 'macro', 'none'}
            method for averaging across meta tags (if given)
        threshold_value : float | list, optional
            the decision threshold value (cannot be set for exclusive classification; defaults to 0.5
            for non-exclusive); if `threshold_value` is a list, every value should correspond to the class
            under the same index

        """
        if threshold_value is None:
            threshold_value = 0.5
        super().__init__(
            num_classes,
            ignore_index,
            average,
            ignored_classes,
            exclusive,
            iou_threshold,
            tag_average,
            [threshold_value],
        )

    def _calculate_metric(self, tp: Dict, fp: Dict, fn: Dict, tn: Dict) -> float:
        """Calculate the metric value from true and false positive and negative rates."""
        k = self.threshold_values[0]
        if isinstance(k, list):
            k = ", ".join(map(str, k))
        return tp[k] / (tp[k] + fp[k] + 1e-7)
Segmental precision (not advised for training).
Initialize the class.
Parameters
ignore_index : int, default -100
the class index that indicates ignored samples
average: {'macro', 'micro', 'none'}
method for averaging across classes
num_classes : int
number of classes
ignored_classes : set, optional
a set of class ids to ignore in calculation
exclusive: bool, default True
set to False for multi-label classification tasks
iou_threshold : float, default 0.5
if segmental is true, intervals with IoU larger than this threshold are considered correct
tag_average: {'micro', 'macro', 'none'}
method for averaging across meta tags (if given)
threshold_value : float | list, optional
the decision threshold value (cannot be set for exclusive classification; defaults to 0.5
for non-exclusive); if threshold_value is a list, every value should correspond to the class
under the same index
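Segmental metrics compare intervals rather than individual frames, so frame-wise predictions first have to be grouped into segments. A standalone sketch of that grouping with torch.unique_consecutive, the same idea as the module's _get_intervals helper (whether SegmentalPrecision shares that exact helper internally is an assumption):

import torch

frames = torch.tensor([0, 1, 1, 1, 0, 0, 1, 1, 0])  # per-frame binary prediction
values, inverse = torch.unique_consecutive(frames, return_inverse=True)
segments = []
for i in torch.where(values == 1)[0]:
    idx = (inverse == i).nonzero(as_tuple=True)[0]
    segments.append((int(idx[0]), int(idx[-1]) + 1))  # half-open [start, end)
print(segments)  # [(1, 4), (6, 8)]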
class Recall(_ClassificationMetric):
    """Recall."""

    def __init__(
        self,
        num_classes: int,
        ignore_index: int = -100,
        average: str = "macro",
        ignored_classes: Set = None,
        exclusive: bool = True,
        tag_average: str = "micro",
        threshold_value: Union[float, List] = None,
    ):
        """Initialize the class.

        Parameters
        ----------
        ignore_index : int, default -100
            the class index that indicates ignored samples
        average: {'macro', 'micro', 'none'}
            method for averaging across classes
        num_classes : int
            number of classes
        ignored_classes : set, optional
            a set of class ids to ignore in calculation
        exclusive: bool, default True
            set to False for multi-label classification tasks
        tag_average: {'micro', 'macro', 'none'}
            method for averaging across meta tags (if given)
        threshold_value : float | list, optional
            the decision threshold value (cannot be set for exclusive classification; defaults to 0.5
            for non-exclusive); if `threshold_value` is a list, every value should correspond to the class
            under the same index

        """
        if threshold_value is None:
            threshold_value = 0.5
        super().__init__(
            num_classes,
            ignore_index,
            average,
            ignored_classes,
            exclusive,
            tag_average=tag_average,
            threshold_values=[threshold_value],
        )

    def _calculate_metric(self, tp: Dict, fp: Dict, fn: Dict, tn: Dict) -> float:
        """Calculate the metric value from true and false positive and negative rates."""
        k = self.threshold_values[0]
        if isinstance(k, list):
            k = ", ".join(map(str, k))
        return tp[k] / (tp[k] + fn[k] + 1e-7)
Recall.
Initialize the class.
Parameters
ignore_index : int, default -100
the class index that indicates ignored samples
average: {'macro', 'micro', 'none'}
method for averaging across classes
num_classes : int
number of classes
ignored_classes : set, optional
a set of class ids to ignore in calculation
exclusive: bool, default True
set to False for multi-label classification tasks
tag_average: {'micro', 'macro', 'none'}
method for averaging across meta tags (if given)
threshold_value : float | list, optional
the decision threshold value (cannot be set for exclusive classification; defaults to 0.5
for non-exclusive); if threshold_value is a list, every value should correspond to the class
under the same index
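A usage sketch for the exclusive (single-label) case, under the same assumed reset/update/calculate workflow as above: predictions are raw class scores that are argmax-ed internally, targets are per-frame class indices, and frames labeled ignore_index are left out of the counts (shapes and values are illustrative):

import torch

metric = Recall(num_classes=3, exclusive=True, average="macro")
metric.reset()
predicted = torch.randn(4, 3, 50)      # raw class scores: batch x classes x frames
target = torch.randint(0, 3, (4, 50))  # ground truth class index per frame
target[0, :5] = -100                   # ignored frames
metric.update(predicted, target, tags=None)
print(metric.calculate())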
class SegmentalRecall(_ClassificationMetric):
    """Segmental recall (not advised for training)."""

    segmental = True

    def __init__(
        self,
        num_classes: int,
        ignore_index: int = -100,
        average: str = "macro",
        ignored_classes: Set = None,
        exclusive: bool = True,
        iou_threshold: float = 0.5,
        tag_average: str = "micro",
        threshold_value: Union[float, List] = None,
    ):
        """Initialize the class.

        Parameters
        ----------
        ignore_index : int, default -100
            the class index that indicates ignored samples
        average: {'macro', 'micro', 'none'}
            method for averaging across classes
        num_classes : int
            number of classes
        ignored_classes : set, optional
            a set of class ids to ignore in calculation
        exclusive: bool, default True
            set to False for multi-label classification tasks
        iou_threshold : float, default 0.5
            intervals with IoU larger than this threshold are considered correct
        tag_average: {'micro', 'macro', 'none'}
            method for averaging across meta tags (if given)
        threshold_value : float | list, optional
            the decision threshold value (cannot be set for exclusive classification; defaults to 0.5
            for non-exclusive); if `threshold_value` is a list, every value should correspond to the class
            under the same index

        """
        if threshold_value is None:
            threshold_value = 0.5
        super().__init__(
            num_classes,
            ignore_index,
            average,
            ignored_classes,
            exclusive,
            iou_threshold,
            tag_average,
            [threshold_value],
        )

    def _calculate_metric(self, tp: Dict, fp: Dict, fn: Dict, tn: Dict) -> float:
        """Calculate the metric value from true and false positive and negative rates."""
        k = self.threshold_values[0]
        if isinstance(k, list):
            k = ", ".join(map(str, k))
        return tp[k] / (tp[k] + fn[k] + 1e-7)
Segmental recall (not advised for training).
Initialize the class.
Parameters
ignore_index : int, default -100
the class index that indicates ignored samples
average: {'macro', 'micro', 'none'}
method for averaging across classes
num_classes : int
number of classes
ignored_classes : set, optional
a set of class ids to ignore in calculation
exclusive: bool, default True
set to False for multi-label classification tasks
iou_threshold : float, default 0.5
if segmental is true, intervals with IoU larger than this threshold are considered correct
tag_average: {'micro', 'macro', 'none'}
method for averaging across meta tags (if given)
threshold_value : float | list, optional
the decision threshold value (cannot be set for exclusive classification; defaults to 0.5
for non-exclusive); if threshold_value is a list, every value should correspond to the class
under the same index
class F1(_ClassificationMetric):
    """F1 score."""

    def __init__(
        self,
        num_classes: int,
        ignore_index: int = -100,
        average: str = "macro",
        ignored_classes: Set = None,
        exclusive: bool = True,
        tag_average: str = "micro",
        threshold_value: Union[float, List] = None,
        integration_interval: int = 0,
    ):
        """Initialize the class.

        Parameters
        ----------
        ignore_index : int, default -100
            the class index that indicates ignored samples
        average: {'macro', 'micro', 'none'}
            method for averaging across classes
        num_classes : int
            number of classes
        ignored_classes : set, optional
            a set of class ids to ignore in calculation
        exclusive: bool, default True
            set to False for multi-label classification tasks
        tag_average: {'micro', 'macro', 'none'}
            method for averaging across meta tags (if given)
        threshold_value : float | list, optional
            the decision threshold value (cannot be set for exclusive classification; defaults to 0.5
            for non-exclusive); if `threshold_value` is a list, every value should correspond to the class
            under the same index
        integration_interval : int, default 0
            the number of frames to integrate over (0 means no integration)

        """
        if threshold_value is None:
            threshold_value = 0.5
        super().__init__(
            num_classes,
            ignore_index,
            average,
            ignored_classes,
            exclusive,
            tag_average=tag_average,
            threshold_values=[threshold_value],
            integration_interval=integration_interval,
        )

    def _calculate_metric(self, tp: Dict, fp: Dict, fn: Dict, tn: Dict) -> float:
        """Calculate the metric value from true and false positive and negative rates."""
        if self.optimize:
            scores = []
            for k in self.threshold_values:
                recall = tp[k] / (tp[k] + fn[k] + 1e-7)
                precision = tp[k] / (tp[k] + fp[k] + 1e-7)
                scores.append(2 * recall * precision / (recall + precision + 1e-7))
            f1 = max(scores)
        else:
            k = self.threshold_values[0]
            if isinstance(k, list):
                k = ", ".join(map(str, k))
            recall = tp[k] / (tp[k] + fn[k] + 1e-7)
            precision = tp[k] / (tp[k] + fp[k] + 1e-7)
            f1 = 2 * recall * precision / (recall + precision + 1e-7)
        return f1
F1 score.
Initialize the class.
Parameters
ignore_index : int, default -100
the class index that indicates ignored samples
average: {'macro', 'micro', 'none'}
method for averaging across classes
num_classes : int
number of classes
ignored_classes : set, optional
a set of class ids to ignore in calculation
exclusive: bool, default True
set to False for multi-label classification tasks
tag_average: {'micro', 'macro', 'none'}
method for averaging across meta tags (if given)
threshold_value : float | list, optional
the decision threshold value (cannot be set for exclusive classification; defaults to 0.5
for non-exclusive); if threshold_value is a list, every value should correspond to the class
under the same index
integration_interval : int, default 0
the number of frames to integrate over (0 means no integration)
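A rough standalone reading of the integration step (assumption: it amounts to a centered moving average over the frame axis, applied to the raw predictions before thresholding; this sketch zero-pads the edges instead of renormalizing them as the implementation does):

import torch
import torch.nn.functional as F

def integrate(predicted: torch.Tensor, k: int) -> torch.Tensor:
    # moving average with window 2*k + 1 over batch x classes x frames input
    if k == 0:
        return predicted
    b, c, t = predicted.shape
    kernel = torch.ones(1, 1, 2 * k + 1) / (2 * k + 1)
    smoothed = F.conv1d(predicted.reshape(b * c, 1, t), kernel, padding=k)
    return smoothed.reshape(b, c, t)

smoothed = integrate(torch.rand(2, 3, 100), k=5)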
class SegmentalF1(_ClassificationMetric):
    """Segmental F1 score (not advised for training)."""

    segmental = True

    def __init__(
        self,
        num_classes: int,
        ignore_index: int = -100,
        average: str = "macro",
        ignored_classes: Set = None,
        exclusive: bool = True,
        iou_threshold: float = 0.5,
        tag_average: str = "micro",
        threshold_value: Union[float, List] = None,
    ):
        """Initialize the class.

        Parameters
        ----------
        ignore_index : int, default -100
            the class index that indicates ignored samples
        average: {'macro', 'micro', 'none'}
            method for averaging across classes
        num_classes : int
            number of classes
        ignored_classes : set, optional
            a set of class ids to ignore in calculation
        exclusive: bool, default True
            set to False for multi-label classification tasks
        iou_threshold : float, default 0.5
            intervals with IoU larger than this threshold are considered correct
        tag_average: {'micro', 'macro', 'none'}
            method for averaging across meta tags (if given)
        threshold_value : float | list, optional
            the decision threshold value (cannot be set for exclusive classification; defaults to 0.5
            for non-exclusive); if `threshold_value` is a list, every value should correspond to the class
            under the same index

        """
        if threshold_value is None:
            threshold_value = 0.5
        super().__init__(
            num_classes,
            ignore_index,
            average,
            ignored_classes,
            exclusive,
            iou_threshold,
            tag_average,
            [threshold_value],
        )

    def _calculate_metric(self, tp: Dict, fp: Dict, fn: Dict, tn: Dict) -> float:
        """Calculate the metric value from true and false positive and negative rates."""
        if self.optimize:
            scores = []
            for k in self.threshold_values:
                recall = tp[k] / (tp[k] + fn[k] + 1e-7)
                precision = tp[k] / (tp[k] + fp[k] + 1e-7)
                scores.append(2 * recall * precision / (recall + precision + 1e-7))
            f1 = max(scores)
        else:
            k = self.threshold_values[0]
            if isinstance(k, list):
                k = ", ".join(map(str, k))
            recall = tp[k] / (tp[k] + fn[k] + 1e-7)
            precision = tp[k] / (tp[k] + fp[k] + 1e-7)
            f1 = 2 * recall * precision / (recall + precision + 1e-7)
        return f1
Segmental F1 score (not advised for training).
Initialize the class.
Parameters
ignore_index : int, default -100
the class index that indicates ignored samples
average: {'macro', 'micro', 'none'}
method for averaging across classes
num_classes : int
number of classes
ignored_classes : set, optional
a set of class ids to ignore in calculation
exclusive: bool, default True
set to False for multi-label classification tasks
iou_threshold : float, default 0.5
if segmental is true, intervals with IoU larger than this threshold are considered correct
tag_average: {'micro', 'macro', 'none'}
method for averaging across meta tags (if given)
threshold_value : float | list, optional
the decision threshold value (cannot be set for exclusive classification; defaults to 0.5
for non-exclusive); if threshold_value is a list, every value should correspond to the class
under the same index
class Fbeta(_ClassificationMetric):
    """F-beta score."""

    def __init__(
        self,
        beta: float = 1,
        ignore_index: int = -100,
        average: str = "macro",
        num_classes: int = None,
        ignored_classes: Set = None,
        tag_average: str = "micro",
        exclusive: bool = True,
        threshold_value: float = 0.5,
    ):
        """Initialize the class.

        Parameters
        ----------
        beta : float, default 1
            the beta parameter
        ignore_index : int, default -100
            the class index that indicates ignored samples
        average: {'macro', 'micro', 'none'}
            method for averaging across classes
        num_classes : int, optional
            number of classes
        ignored_classes : set, optional
            a set of class ids to ignore in calculation
        exclusive: bool, default True
            set to False for multi-label classification tasks
        tag_average: {'micro', 'macro', 'none'}
            method for averaging across meta tags (if given)
        threshold_value : float | list, optional
            the decision threshold value (cannot be set for exclusive classification; defaults to 0.5
            for non-exclusive); if `threshold_value` is a list, every value should correspond to the class
            under the same index

        """
        if threshold_value is None:
            threshold_value = 0.5
        self.beta2 = beta**2
        super().__init__(
            num_classes,
            ignore_index,
            average,
            ignored_classes,
            exclusive,
            tag_average=tag_average,
            threshold_values=[threshold_value],
        )

    def _calculate_metric(self, tp: Dict, fp: Dict, fn: Dict, tn: Dict) -> float:
        """Calculate the metric value from true and false positive and negative rates."""
        if self.optimize:
            scores = []
            for k in self.threshold_values:
                recall = tp[k] / (tp[k] + fn[k] + 1e-7)
                precision = tp[k] / (tp[k] + fp[k] + 1e-7)
                scores.append(
                    (
                        (1 + self.beta2)
                        * precision
                        * recall
                        / (self.beta2 * precision + recall + 1e-7)
                    )
                )
            f1 = max(scores)
        else:
            k = self.threshold_values[0]
            if isinstance(k, list):
                k = ", ".join(map(str, k))
            recall = tp[k] / (tp[k] + fn[k] + 1e-7)
            precision = tp[k] / (tp[k] + fp[k] + 1e-7)
            f1 = (
                (1 + self.beta2)
                * precision
                * recall
                / (self.beta2 * precision + recall + 1e-7)
            )
        return f1
F-beta score.
Initialize the class.
Parameters
beta : float, default 1
the beta parameter
ignore_index : int, default -100
the class index that indicates ignored samples
average: {'macro', 'micro', 'none'}
method for averaging across classes
num_classes : int, optional
number of classes
ignored_classes : set, optional
a set of class ids to ignore in calculation
exclusive: bool, default True
set to False for multi-label classification tasks
tag_average: {'micro', 'macro', 'none'}
method for averaging across meta tags (if given)
threshold_value : float | list, optional
the decision threshold value (cannot be set for exclusive classification; defaults to 0.5
for non-exclusive); if threshold_value is a list, every value should correspond to the class
under the same index
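The F-beta combination used in _calculate_metric, written out as a tiny standalone helper (the eps term mirrors the module's 1e-7 stabilizer):

def fbeta_score(precision: float, recall: float, beta: float = 1.0, eps: float = 1e-7) -> float:
    # beta > 1 weights recall more heavily, beta < 1 weights precision more heavily
    b2 = beta ** 2
    return (1 + b2) * precision * recall / (b2 * precision + recall + eps)

print(fbeta_score(0.8, 0.6, beta=2))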
class SegmentalFbeta(_ClassificationMetric):
    """Segmental F-beta score (not advised for training)."""

    segmental = True

    def __init__(
        self,
        beta: float = 1,
        ignore_index: int = -100,
        average: str = "macro",
        num_classes: int = None,
        ignored_classes: Set = None,
        iou_threshold: float = 0.5,
        tag_average: str = "micro",
        exclusive: bool = True,
        threshold_value: float = 0.5,
    ):
        """Initialize the class.

        Parameters
        ----------
        beta : float, default 1
            the beta parameter
        ignore_index : int, default -100
            the class index that indicates ignored samples
        average: {'macro', 'micro', 'none'}
            method for averaging across classes
        num_classes : int, optional
            number of classes
        ignored_classes : set, optional
            a set of class ids to ignore in calculation
        exclusive: bool, default True
            set to False for multi-label classification tasks
        iou_threshold : float, default 0.5
            intervals with IoU larger than this threshold are considered correct
        tag_average: {'micro', 'macro', 'none'}
            method for averaging across meta tags (if given)
        threshold_value : float | list, optional
            the decision threshold value (cannot be set for exclusive classification; defaults to 0.5
            for non-exclusive); if `threshold_value` is a list, every value should correspond to the class
            under the same index

        """
        if threshold_value is None:
            threshold_value = 0.5
        self.beta2 = beta**2
        super().__init__(
            num_classes,
            ignore_index,
            average,
            ignored_classes,
            exclusive,
            iou_threshold=iou_threshold,
            tag_average=tag_average,
            threshold_values=[threshold_value],
        )

    def _calculate_metric(self, tp: Dict, fp: Dict, fn: Dict, tn: Dict) -> float:
        """Calculate the metric value from true and false positive and negative rates."""
        if self.optimize:
            scores = []
            for k in self.threshold_values:
                recall = tp[k] / (tp[k] + fn[k] + 1e-7)
                precision = tp[k] / (tp[k] + fp[k] + 1e-7)
                scores.append(
                    (
                        (1 + self.beta2)
                        * precision
                        * recall
                        / (self.beta2 * precision + recall + 1e-7)
                    )
                )
            f1 = max(scores)
        else:
            k = self.threshold_values[0]
            if isinstance(k, list):
                k = ", ".join(map(str, k))
            recall = tp[k] / (tp[k] + fn[k] + 1e-7)
            precision = tp[k] / (tp[k] + fp[k] + 1e-7)
            f1 = (
                (1 + self.beta2)
                * precision
                * recall
                / (self.beta2 * precision + recall + 1e-7)
            )
        return f1
Segmental F-beta score (not advised for training).
Initialize the class.
Parameters
beta : float, default 1
the beta parameter
ignore_index : int, default -100
the class index that indicates ignored samples
average: {'macro', 'micro', 'none'}
method for averaging across classes
num_classes : int, optional
number of classes
ignored_classes : set, optional
a set of class ids to ignore in calculation
exclusive: bool, default True
set to False for multi-label classification tasks
iou_threshold : float, default 0.5
if segmental is true, intervals with IoU larger than this threshold are considered correct
tag_average: {'micro', 'macro', 'none'}
method for averaging across meta tags (if given)
threshold_value : float | list, optional
the decision threshold value (cannot be set for exclusive classification; defaults to 0.5
for non-exclusive); if threshold_value is a list, every value should correspond to the class
under the same index
class SemiSegmentalRecall(_SemiSegmentalMetric):
    """Semi-segmental recall (not advised for training).

    A metric in-between segmental and frame-wise recall.

    The metric is computed with the following algorithm:
    1) smooth over too-short intervals, both in ground truth and in prediction (first remove
    groups of zeros shorter than `smooth_interval` and then do the same with groups of ones),
    2) add `delta` frames to each ground truth interval at both ends and count the number of predicted
    positive frames at the resulting intervals (intersection),
    3) calculate the threshold for each interval as
    `t = sigmoid(4 * (a * x + b)) * (iou_threshold_long - iou_threshold_short)`, where
    `a = 2 / (long_length - short_length)`, `b = 1 - a * long_length`, `x` is the length of the interval
    before `delta` was added,
    4) for each interval, if intersection is higher than `t * x`, the interval is labeled as true positive (`TP`),
    and otherwise as false negative (`FN`),
    5) the final metric value is computed as `TP / (TP + FN)`.
    """

    def __init__(
        self,
        num_classes: int,
        ignore_index: int = -100,
        ignored_classes: Set = None,
        exclusive: bool = True,
        average: str = "macro",
        tag_average: str = "micro",
        delta: int = 0,
        smooth_interval: int = 0,
        iou_threshold_long: float = 0.5,
        iou_threshold_short: float = 0.5,
        short_length: int = 30,
        long_length: int = 300,
        threshold_value: Union[float, List] = None,
    ) -> None:
        """Initialize the class.

        Parameters
        ----------
        num_classes : int
            the number of classes in the dataset
        ignore_index : int, default -100
            the ground truth label to ignore
        ignored_classes : set, optional
            the class indices to ignore in computation
        exclusive : bool, default True
            `False` for multi-label classification tasks
        average : {"macro", "micro", "none"}
            the method to average the results over classes
        tag_average : {"macro", "micro", "none"}
            the method to average the results over meta tags (if given)
        delta : int, default 0
            the number of frames to add to each ground truth interval before computing the intersection,
            see description of the class for details
        smooth_interval : int, default 0
            intervals shorter than this number of frames will be ignored (both in prediction and in
            ground truth), see description of the class for details
        iou_threshold_long : float, default 0.5
            the intersection threshold for segments longer than `long_length` frames (between 0 and 1),
            see description of the class for details
        iou_threshold_short : float, default 0.5
            the intersection threshold for segments shorter than `short_length` frames (between 0 and 1),
            see description of the class for details
        short_length : int, default 30
            the threshold number of frames for short intervals that will have an intersection threshold of
            `iou_threshold_short`, see description of the class for details
        long_length : int, default 300
            the threshold number of frames for long intervals that will have an intersection threshold of
            `iou_threshold_long`, see description of the class for details
        threshold_value : float | list, optional
            the decision threshold value (cannot be set for exclusive classification; defaults to 0.5)

        """
        super().__init__(
            num_classes,
            ignore_index,
            ignored_classes,
            exclusive,
            average,
            tag_average,
            delta,
            smooth_interval,
            iou_threshold_long,
            iou_threshold_short,
            short_length,
            long_length,
            [threshold_value],
        )

    def _calculate_metric(self, tp: Dict, fp: Dict, fn: Dict, tn: Dict) -> float:
        """Calculate the metric value from true and false positive and negative rates."""
        k = self.threshold_values[0]
        if isinstance(k, list):
            k = ", ".join(map(str, k))
        return tp[k] / (tp[k] + fn[k] + 1e-7)
Semi-segmental recall (not advised for training).
A metric in-between segmental and frame-wise recall.
The metric is computed with the following algorithm:
1) smooth over too-short intervals, both in ground truth and in prediction (first remove
groups of zeros shorter than smooth_interval and then do the same with groups of ones),
2) add delta frames to each ground truth interval at both ends and count the number of predicted
positive frames at the resulting intervals (intersection),
3) calculate the threshold for each interval as
t = sigmoid(4 * (a * x + b)) * (iou_threshold_long - iou_threshold_short), where
a = 2 / (long_length - short_length), b = 1 - a * long_length, and x is the length of the interval
before delta was added (a worked numeric sketch follows this list),
4) for each interval, if intersection is higher than t * x, the interval is labeled as true positive (TP),
and otherwise as false negative (FN),
5) the final metric value is computed as TP / (TP + FN).
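A hedged numeric sketch of the length-adaptive threshold from step 3, using the formula exactly as stated above. Note that with the default iou_threshold_long == iou_threshold_short the expression collapses to 0, so distinct values are used here for illustration:

import numpy as np

def length_adaptive_threshold(x, short_length=30, long_length=300,
                              iou_threshold_short=0.3, iou_threshold_long=0.7):
    # the docstring formula, taken verbatim; x is the interval length in frames
    a = 2 / (long_length - short_length)
    b = 1 - a * long_length
    sigmoid = 1 / (1 + np.exp(-4 * (a * x + b)))
    return sigmoid * (iou_threshold_long - iou_threshold_short)

for x in (30, 150, 300):
    print(x, round(length_adaptive_threshold(x), 3))  # grows with interval length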
Initialize the class.
Parameters
num_classes : int
the number of classes in the dataset
ignore_index : int, default -100
the ground truth label to ignore
ignored_classes : set, optional
the class indices to ignore in computation
exclusive : bool, default True
False for multi-label classification tasks
average : {"macro", "micro", "none"}
the method to average the results over classes
tag_average : {"macro", "micro", "none"}
the method to average the results over meta tags (if given)
delta : int, default 0
the number of frames to add to each ground truth interval before computing the intersection,
see description of the class for details
smooth_interval : int, default 0
intervals shorter than this number of frames will be ignored (both in prediction and in ground truth);
see description of the class for details
iou_threshold_long : float, default 0.5
the intersection threshold for segments longer than long_length frames (between 0 and 1),
see description of the class for details
iou_threshold_short : float, default 0.5
the intersection threshold for segments shorter than short_length frames (between 0 and 1),
see description of the class for details
short_length : int, default 30
the threshold number of frames for short intervals that will have an intersection threshold of
iou_threshold_short, see description of the class for details
long_length : int, default 300
the threshold number of frames for long intervals that will have an intersection threshold of
iou_threshold_long, see description of the class for details
threshold_value : float | list, optional
the decision threshold value (cannot be defined for exclusive classification, and 0.5 be default)
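All metrics on this page share the same epoch protocol: `reset()` at the start of an epoch, `update()` once per batch, and `calculate()` at the end. A minimal usage sketch follows; the class name `SemiSegmentalRecall` and the tensor shapes (raw scores of shape `(batch, classes, frames)`, integer targets of shape `(batch, frames)`) are assumptions based on the surrounding code, not confirmed by this page.

import torch

from dlc2action.metric.metrics import SemiSegmentalRecall  # assumed class name

metric = SemiSegmentalRecall(num_classes=3, exclusive=True)
metric.reset()  # at the start of an epoch
predicted = torch.rand(2, 3, 100)  # assumed: raw scores, (batch, classes, frames)
target = torch.randint(0, 3, (2, 100))  # assumed: class indices, (batch, frames)
metric.update(predicted, target, tags=None)  # once per batch
print(metric.calculate())  # at the end of the epoch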
class SemiSegmentalPrecision(_SemiSegmentalMetric):
    """Semi-segmental precision (not advised for training).

    A metric in-between segmental and frame-wise precision.

    The metric is computed as follows:
    1) smooth over too-short intervals, both in ground truth and in prediction (first remove
       groups of zeros shorter than `smooth_interval` and then do the same with groups of ones),
    2) add `delta` frames to each predicted interval at both ends and count the number of ground
       truth positive frames at the resulting intervals (intersection),
    3) calculate the threshold for each interval as
       `t = sigmoid(4 * (a * x + b)) * (iou_threshold_long - iou_threshold_short) + iou_threshold_short`,
       where `a = 2 / (long_length - short_length)`, `b = 1 - a * long_length` and `x` is the length
       of the interval before `delta` was added,
    4) for each interval, if the intersection is higher than `t * x`, the interval is labeled as
       true positive (`TP`), and otherwise as false positive (`FP`),
    5) the final metric value is computed as `TP / (TP + FP)`.

    """

    def __init__(
        self,
        num_classes: int,
        ignore_index: int = -100,
        ignored_classes: Set = None,
        exclusive: bool = True,
        average: str = "macro",
        tag_average: str = "micro",
        delta: int = 0,
        smooth_interval: int = 0,
        iou_threshold_long: float = 0.5,
        iou_threshold_short: float = 0.5,
        short_length: int = 30,
        long_length: int = 300,
        threshold_value: Union[float, List] = None,
    ) -> None:
        """Initialize the class.

        Parameters
        ----------
        num_classes : int
            the number of classes in the dataset
        ignore_index : int, default -100
            the ground truth label to ignore
        ignored_classes : set, optional
            the class indices to ignore in computation
        exclusive : bool, default True
            `False` for multi-label classification tasks
        average : {"macro", "micro", "none"}
            the method to average the results over classes
        tag_average : {"macro", "micro", "none"}
            the method to average the results over meta tags (if given)
        delta : int, default 0
            the number of frames to add to each ground truth interval before computing the intersection,
            see description of the class for details
        smooth_interval : int, default 0
            intervals shorter than this number of frames will be ignored (both in prediction and in
            ground truth), see description of the class for details
        iou_threshold_long : float, default 0.5
            the intersection threshold for segments longer than `long_length` frames (between 0 and 1),
            see description of the class for details
        iou_threshold_short : float, default 0.5
            the intersection threshold for segments shorter than `short_length` frames (between 0 and 1),
            see description of the class for details
        short_length : int, default 30
            the threshold number of frames for short intervals that will have an intersection threshold of
            `iou_threshold_short`, see description of the class for details
        long_length : int, default 300
            the threshold number of frames for long intervals that will have an intersection threshold of
            `iou_threshold_long`, see description of the class for details
        threshold_value : float | list, optional
            the decision threshold value (cannot be defined for exclusive classification; 0.5 by default)

        """
        super().__init__(
            num_classes,
            ignore_index,
            ignored_classes,
            exclusive,
            average,
            tag_average,
            delta,
            smooth_interval,
            iou_threshold_long,
            iou_threshold_short,
            short_length,
            long_length,
            [threshold_value],
        )

    def _calculate_metric(self, tp: Dict, fp: Dict, fn: Dict, tn: Dict) -> float:
        """Calculate the metric value from true and false positive and negative rates."""
        k = self.threshold_values[0]
        if isinstance(k, list):
            k = ", ".join(map(str, k))
        # note: in this family of metrics `tn` appears to accumulate matched *predicted*
        # intervals (the precision direction), while `tp`/`fn` are counted over ground
        # truth intervals, so precision is computed from `tn` and `fp` here
        return tn[k] / (tn[k] + fp[k] + 1e-7)
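The interval-dependent threshold in step 3 of the docstring above can be checked numerically. A minimal sketch; the `+ iou_threshold_short` offset is an assumption made when repairing the garbled formula (it is what makes the default `iou_threshold_long = iou_threshold_short = 0.5` reduce to a constant 0.5 threshold), and the 0.4/0.6 thresholds are illustrative:

import numpy as np

def interval_threshold(x, iou_short=0.4, iou_long=0.6, short_length=30, long_length=300):
    """Interpolate the intersection threshold between the short- and long-interval regimes."""
    a = 2 / (long_length - short_length)
    b = 1 - a * long_length
    sigmoid = 1 / (1 + np.exp(-4 * (a * x + b)))
    return sigmoid * (iou_long - iou_short) + iou_short

for x in [10, 30, 150, 300, 1000]:
    print(x, round(interval_threshold(x), 3))
# intervals around short_length get ~iou_short, intervals around long_length
# get ~iou_long, with a smooth sigmoid transition in between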
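Step 1 of the algorithm (smoothing) can be illustrated on a binary sequence. A minimal sketch, assuming smoothing means first filling runs of zeros shorter than `smooth_interval` and then dropping runs of ones shorter than `smooth_interval` (the exact semantics live in `_SemiSegmentalMetric`, which is not shown on this page):

import numpy as np

def smooth(x: np.ndarray, smooth_interval: int) -> np.ndarray:
    """Remove runs of zeros, then runs of ones, shorter than smooth_interval."""

    def remove_short_runs(seq, value, min_len):
        out = seq.copy()
        start = None
        for i, v in enumerate(np.append(seq, 1 - value)):  # sentinel flushes the last run
            if v == value and start is None:
                start = i
            elif v != value and start is not None:
                if i - start < min_len:
                    out[start:i] = 1 - value
                start = None
        return out

    x = remove_short_runs(x, 0, smooth_interval)  # fill short gaps
    return remove_short_runs(x, 1, smooth_interval)  # drop short blips

seq = np.array([1, 1, 0, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0])
print(smooth(seq, 3))  # -> [1 1 1 1 1 1 0 0 0 0 0 0 0 0]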
class SemiSegmentalF1(_SemiSegmentalMetric):
    """The F1 score for semi-segmental recall and precision (not advised for training)."""

    def __init__(
        self,
        num_classes: int,
        ignore_index: int = -100,
        ignored_classes: Set = None,
        exclusive: bool = True,
        average: str = "macro",
        tag_average: str = "micro",
        delta: int = 0,
        smooth_interval: int = 0,
        iou_threshold_long: float = 0.5,
        iou_threshold_short: float = 0.5,
        short_length: int = 30,
        long_length: int = 300,
        threshold_value: Union[float, List] = None,
    ) -> None:
        """Initialize the class.

        Parameters
        ----------
        num_classes : int
            the number of classes in the dataset
        ignore_index : int, default -100
            the ground truth label to ignore
        ignored_classes : set, optional
            the class indices to ignore in computation
        exclusive : bool, default True
            `False` for multi-label classification tasks
        average : {"macro", "micro", "none"}
            the method to average the results over classes
        tag_average : {"macro", "micro", "none"}
            the method to average the results over meta tags (if given)
        delta : int, default 0
            the number of frames to add to each ground truth interval before computing the intersection,
            see description of the class for details
        smooth_interval : int, default 0
            intervals shorter than this number of frames will be ignored (both in prediction and in
            ground truth), see description of the class for details
        iou_threshold_long : float, default 0.5
            the intersection threshold for segments longer than `long_length` frames (between 0 and 1),
            see description of the class for details
        iou_threshold_short : float, default 0.5
            the intersection threshold for segments shorter than `short_length` frames (between 0 and 1),
            see description of the class for details
        short_length : int, default 30
            the threshold number of frames for short intervals that will have an intersection threshold of
            `iou_threshold_short`, see description of the class for details
        long_length : int, default 300
            the threshold number of frames for long intervals that will have an intersection threshold of
            `iou_threshold_long`, see description of the class for details
        threshold_value : float | list, optional
            the decision threshold value (cannot be defined for exclusive classification; 0.5 by default)

        """
        super().__init__(
            num_classes,
            ignore_index,
            ignored_classes,
            exclusive,
            average,
            tag_average,
            delta,
            smooth_interval,
            iou_threshold_long,
            iou_threshold_short,
            short_length,
            long_length,
            [threshold_value],
        )

    def _calculate_metric(self, tp: Dict, fp: Dict, fn: Dict, tn: Dict) -> float:
        """Calculate the metric value from true and false positive and negative rates."""
        if self.optimize:
            scores = []
            for k in self.threshold_values:
                recall = tp[k] / (tp[k] + fn[k] + 1e-7)
                precision = tn[k] / (tn[k] + fp[k] + 1e-7)
                scores.append(2 * recall * precision / (recall + precision + 1e-7))
            f1 = max(scores)
        else:
            k = self.threshold_values[0]
            if isinstance(k, list):
                k = ", ".join(map(str, k))
            recall = tp[k] / (tp[k] + fn[k] + 1e-7)
            precision = tn[k] / (tn[k] + fp[k] + 1e-7)
            f1 = 2 * recall * precision / (recall + precision + 1e-7)
        return f1
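The harmonic mean in `SemiSegmentalF1._calculate_metric` can be checked in isolation; a minimal sketch with made-up counts (here, as in the class above, `tn`/`fp` play the precision role and `tp`/`fn` the recall role):

tp, fn = 80, 20  # ground-truth intervals: matched / missed
tn, fp = 70, 30  # predicted intervals: matched / spurious
recall = tp / (tp + fn + 1e-7)  # 0.8
precision = tn / (tn + fp + 1e-7)  # 0.7
f1 = 2 * recall * precision / (recall + precision + 1e-7)
print(round(f1, 4))  # 2 * 0.8 * 0.7 / 1.5 -> 0.7467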
class SemiSegmentalPR_AUC(_SemiSegmentalMetric):
    """The area under the precision-recall curve for semi-segmental metrics (not advised for training)."""

    def __init__(
        self,
        num_classes: int,
        ignore_index: int = -100,
        ignored_classes: Set = None,
        exclusive: bool = True,
        average: str = "macro",
        tag_average: str = "micro",
        delta: int = 0,
        smooth_interval: int = 0,
        iou_threshold_long: float = 0.5,
        iou_threshold_short: float = 0.5,
        short_length: int = 30,
        long_length: int = 300,
        threshold_step: float = 0.1,
    ) -> None:
        """Initialize the class.

        Parameters
        ----------
        num_classes : int
            the number of classes in the dataset
        ignore_index : int, default -100
            the ground truth label to ignore
        ignored_classes : set, optional
            the class indices to ignore in computation
        exclusive : bool, default True
            `False` for multi-label classification tasks
        average : {"macro", "micro", "none"}
            the method to average the results over classes
        tag_average : {"macro", "micro", "none"}
            the method to average the results over meta tags (if given)
        delta : int, default 0
            the number of frames to add to each ground truth interval before computing the intersection,
            see description of the class for details
        smooth_interval : int, default 0
            intervals shorter than this number of frames will be ignored (both in prediction and in
            ground truth), see description of the class for details
        iou_threshold_long : float, default 0.5
            the intersection threshold for segments longer than `long_length` frames (between 0 and 1),
            see description of the class for details
        iou_threshold_short : float, default 0.5
            the intersection threshold for segments shorter than `short_length` frames (between 0 and 1),
            see description of the class for details
        short_length : int, default 30
            the threshold number of frames for short intervals that will have an intersection threshold of
            `iou_threshold_short`, see description of the class for details
        long_length : int, default 300
            the threshold number of frames for long intervals that will have an intersection threshold of
            `iou_threshold_long`, see description of the class for details
        threshold_step : float, default 0.1
            the step between the decision threshold values that are evaluated (thresholds are taken
            from 0 to 1 with this step; cannot be defined for exclusive classification)

        """
        super().__init__(
            num_classes,
            ignore_index,
            ignored_classes,
            exclusive,
            average,
            tag_average,
            delta,
            smooth_interval,
            iou_threshold_long,
            iou_threshold_short,
            short_length,
            long_length,
            list(np.arange(0, 1, threshold_step)),
        )

    def _calculate_metric(self, tp: Dict, fp: Dict, fn: Dict, tn: Dict) -> float:
        """Calculate the metric value from true and false positive and negative rates."""
        precisions = []
        recalls = []
        for k in sorted(self.threshold_values):
            precisions.append(tn[k] / (tn[k] + fp[k] + 1e-7))
            recalls.append(tp[k] / (tp[k] + fn[k] + 1e-7))
        return metrics.auc(x=recalls, y=precisions)
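`SemiSegmentalPR_AUC._calculate_metric` reduces to handing per-threshold precision/recall pairs to `sklearn.metrics.auc`; a minimal sketch with synthetic values (the numbers below are made up for illustration):

from sklearn import metrics

# made-up per-threshold values; in the class these come from the accumulated
# tp/fp/fn/tn dictionaries, keyed by the decision threshold
recalls = [0.99, 0.95, 0.90, 0.82, 0.70, 0.60, 0.45, 0.30, 0.15, 0.05]
precisions = [0.20, 0.30, 0.42, 0.50, 0.60, 0.68, 0.75, 0.82, 0.90, 0.95]
print(metrics.auc(x=recalls, y=precisions))  # area under the precision-recall curve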
class Accuracy(Metric):
    """Accuracy."""

    def __init__(self, ignore_index=-100):
        """Initialize the class.

        Parameters
        ----------
        ignore_index : int
            the class index that indicates ignored samples

        """
        super().__init__()
        self.ignore_index = ignore_index

    def reset(self) -> None:
        """Reset the internal state (at the beginning of an epoch)."""
        self.total = 0
        self.correct = 0

    def calculate(self) -> float:
        """Calculate the metric value.

        Returns
        -------
        metric : float
            metric value

        """
        return self.correct / (self.total + 1e-7)

    def update(
        self,
        predicted: torch.Tensor,
        target: torch.Tensor,
        tags: torch.Tensor = None,
    ) -> None:
        """Update the internal state (with a batch).

        Parameters
        ----------
        predicted : torch.Tensor
            the main prediction tensor generated by the model
        target : torch.Tensor
            the corresponding main target tensor
        tags : torch.Tensor
            the tensor of meta tags (or `None`, if tags are not given)

        """
        mask = target != self.ignore_index
        self.total += torch.sum(mask)
        self.correct += torch.sum((target == predicted)[mask])
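A minimal usage sketch for `Accuracy` with made-up tensors, showing how `ignore_index` frames are masked out:

import torch

from dlc2action.metric.metrics import Accuracy

metric = Accuracy(ignore_index=-100)
metric.reset()
predicted = torch.tensor([0, 1, 2, 2, 1])
target = torch.tensor([0, 1, 1, -100, -100])  # the last two frames are ignored
metric.update(predicted, target)
print(metric.calculate())  # 2 correct out of 3 non-ignored frames -> ~0.6667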
class Count(Metric):
    """Fraction of samples labeled by the model as a class."""

    def __init__(self, classes: Set, exclusive: bool = True):
        """Initialize the class.

        Parameters
        ----------
        classes : set
            the set of classes to count
        exclusive : bool, default True
            set to `False` for multi-label classification tasks

        """
        super().__init__()
        self.classes = classes
        self.exclusive = exclusive

    def reset(self) -> None:
        """Reset the internal state (at the beginning of an epoch)."""
        self.count = defaultdict(lambda: 0)
        self.total = 0

    def update(
        self,
        predicted: torch.Tensor,
        target: torch.Tensor,
        tags: torch.Tensor,
    ) -> None:
        """Update the internal state (with a batch).

        Parameters
        ----------
        predicted : torch.Tensor
            the main prediction tensor generated by the model
        target : torch.Tensor
            the corresponding main target tensor
        tags : torch.Tensor
            the tensor of meta tags (or `None`, if tags are not given)

        """
        if self.exclusive:
            for c in self.classes:
                self.count[c] += torch.sum(predicted == c)
            self.total += torch.numel(predicted)
        else:
            for c in self.classes:
                self.count[c] += torch.sum(predicted[:, c, :] == 1)
            self.total += torch.numel(predicted[:, 0, :])

    def calculate(self) -> Dict:
        """Calculate the metric (at the end of an epoch).

        Returns
        -------
        result : dict
            a dictionary where the keys are class indices and the values are class metric values

        """
        # note: this normalizes the counts in place, so it should be called once per epoch
        for c in self.classes:
            self.count[c] = self.count[c] / (self.total + 1e-7)
        return dict(self.count)
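A minimal usage sketch for `Count` in exclusive mode; the shape convention (class-index predictions of shape `(batch, frames)`) is an assumption based on how the exclusive branch indexes `predicted`:

import torch

from dlc2action.metric.metrics import Count

metric = Count(classes={0, 1, 2}, exclusive=True)
metric.reset()
predicted = torch.tensor([[0, 0, 1, 2], [1, 1, 1, 2]])  # class indices per frame
metric.update(predicted, target=None, tags=None)  # target and tags are unused here
print(metric.calculate())  # ~{0: 0.25, 1: 0.5, 2: 0.25}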
class EditDistance(Metric):
    """Edit distance (not advised for training).

    Normalized by the length of the sequences.
    """

    def __init__(self, ignore_index: int = -100, exclusive: bool = True) -> None:
        """Initialize the class.

        Parameters
        ----------
        ignore_index : int, default -100
            the class index that indicates samples that should be ignored
        exclusive : bool, default True
            set to `False` for multi-label classification tasks

        """
        super().__init__()
        self.ignore_index = ignore_index
        self.exclusive = exclusive

    def reset(self) -> None:
        """Reset the internal state (at the beginning of an epoch)."""
        self.edit_distance = 0
        self.total = 0

    def update(
        self,
        predicted: torch.Tensor,
        target: torch.Tensor,
        tags: torch.Tensor,
    ) -> None:
        """Update the internal state (with a batch).

        Parameters
        ----------
        predicted : torch.Tensor
            the main prediction tensor generated by the model
        target : torch.Tensor
            the corresponding main target tensor
        tags : torch.Tensor
            the tensor of meta tags (or `None`, if tags are not given)

        """
        mask = target != self.ignore_index
        self.total += torch.sum(mask)
        if self.exclusive:
            predicted = predicted[mask].flatten()
            target = target[mask].flatten()
            self.edit_distance += editdistance.eval(
                predicted.detach().cpu().numpy(), target.detach().cpu().numpy()
            )
        else:
            for c in range(target.shape[1]):
                predicted_class = predicted[:, c, :][mask[:, c, :]].flatten()
                target_class = target[:, c, :][mask[:, c, :]].flatten()
                self.edit_distance += editdistance.eval(
                    predicted_class.detach().cpu().tolist(),
                    target_class.detach().cpu().tolist(),
                )

    def _is_equal(self, a, b):
        """Compare while ignoring samples marked with `ignore_index`."""
        return self.ignore_index in (a, b) or a == b

    def calculate(self) -> float:
        """Calculate the metric (at the end of an epoch).

        Returns
        -------
        result : float
            the metric value

        """
        return self.edit_distance / (self.total + 1e-7)
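A minimal usage sketch for `EditDistance` with made-up tensors; the non-ignored frames form the two sequences compared by `editdistance.eval`:

import torch

from dlc2action.metric.metrics import EditDistance

metric = EditDistance(ignore_index=-100, exclusive=True)
metric.reset()
predicted = torch.tensor([0, 0, 1, 1, 2])
target = torch.tensor([0, 1, 1, 1, -100])  # the last frame is ignored
metric.update(predicted, target, tags=None)
print(metric.calculate())  # one substitution over 4 counted frames -> 0.25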
class mAP(Metric):
    """Mean average precision (segmental) (not advised for training)."""

    needs_raw_data = True
    """
    If `True`, `dlc2action.task.universal_task.Task` will pass raw data to the metric (with only
    the primary predict function applied); otherwise it will pass a prediction for the classes.
    """

    def __init__(
        self,
        exclusive,
        num_classes,
        average="macro",
        iou_threshold=0.5,
        threshold_value=0.5,
        ignored_classes=None,
    ):
        """Initialize the class.

        Parameters
        ----------
        exclusive : bool
            set to `False` for multi-label classification tasks
        num_classes : int
            the number of classes
        average : {"macro", "micro", "none"}
            the type of averaging to perform
        iou_threshold : float, default 0.5
            the IoU threshold for matching
        threshold_value : float, default 0.5
            the threshold value for binarization
        ignored_classes : list, default None
            the list of classes to ignore

        """
        if ignored_classes is None:
            ignored_classes = []
        self.average = average
        self.iou_threshold = iou_threshold
        self.threshold = threshold_value
        self.exclusive = exclusive
        self.classes = [x for x in list(range(num_classes)) if x not in ignored_classes]
        super().__init__()

    def match(self, lst, ratio, ground):
        """Given a list of proposals, match them to the ground truth intervals."""
        lst = sorted(lst, key=lambda x: x[2])

        def overlap(prop, ground):
            s_p, e_p, _ = prop
            s_g, e_g, _ = ground
            return (min(e_p, e_g) - max(s_p, s_g)) / (max(e_p, e_g) - min(s_p, s_g))

        cos_map = [-1 for x in range(len(lst))]
        count_map = [0 for x in range(len(ground))]

        for x in range(len(lst)):
            for y in range(len(ground)):
                if overlap(lst[x], ground[y]) < ratio:
                    continue
                # note: while the proposal is unmatched, cos_map[x] == -1, so this
                # compares against ground[-1] (the last interval) via negative indexing
                if overlap(lst[x], ground[y]) < overlap(lst[x], ground[cos_map[x]]):
                    continue
                cos_map[x] = y
            if cos_map[x] != -1:
                count_map[cos_map[x]] += 1
        positive = sum([(x > 0) for x in count_map])
        return cos_map, count_map, positive, [x[2] for x in lst]

    def reset(self) -> None:
        """Reset the internal state (at the beginning of an epoch)."""
        self.count_map = defaultdict(lambda: [])
        self.positive = defaultdict(lambda: 0)
        self.cos_map = defaultdict(lambda: [])
        self.confidence = defaultdict(lambda: [])

    def calc_pr(self, positive, proposal, ground):
        """Get precision and recall."""
        if proposal == 0:
            return 0, 0
        if ground == 0:
            return 0, 0
        return (1.0 * positive) / proposal, (1.0 * positive) / ground

    def calculate(self) -> Union[float, Dict]:
        """Calculate the metric (at the end of an epoch)."""
        if self.average == "micro":
            confidence = []
            count_map = []
            cos_map = []
            positive = sum(self.positive.values())
            for key in self.count_map.keys():
                confidence += self.confidence[key]
                cos_map += list(np.array(self.cos_map[key]) + len(count_map))
                count_map += self.count_map[key]
            return self.ap(cos_map, count_map, positive, confidence)
        results = {
            key: self.ap(
                self.cos_map[key],
                self.count_map[key],
                self.positive[key],
                self.confidence[key],
            )
            for key in self.count_map.keys()
        }
        if self.average == "none":
            return results
        else:
            return float(np.mean(list(results.values())))

    def ap(self, cos_map, count_map, positive, confidence):
        """Compute average precision."""
        indices = np.argsort(confidence)
        cos_map = list(np.array(cos_map)[indices])
        score = 0
        number_proposal = len(cos_map)
        number_ground = len(count_map)
        old_precision, old_recall = self.calc_pr(
            positive, number_proposal, number_ground
        )

        for x in range(len(cos_map)):
            number_proposal -= 1
            if cos_map[x] == -1:
                continue
            count_map[cos_map[x]] -= 1
            if count_map[cos_map[x]] == 0:
                positive -= 1

            precision, recall = self.calc_pr(positive, number_proposal, number_ground)
            if precision > old_precision:
                old_precision = precision
            score += old_precision * (old_recall - recall)
            old_recall = recall
        return score

    def _get_intervals(
        self, tensor: torch.Tensor, probability: torch.Tensor = None
    ) -> Union[Tuple, torch.Tensor]:
        """Get `True` group beginning and end indices from a boolean tensor and average probability over these intervals."""
        output, indices = torch.unique_consecutive(tensor, return_inverse=True)
        true_indices = torch.where(output)[0]
        starts = torch.tensor(
            [(indices == i).nonzero(as_tuple=True)[0][0] for i in true_indices]
        )
        ends = torch.tensor(
            [(indices == i).nonzero(as_tuple=True)[0][-1] + 1 for i in true_indices]
        )
        confidence = torch.tensor(
            [probability[indices == i].mean() for i in true_indices]
        )
        return torch.stack([starts, ends, confidence]).T

    def update(
        self,
        predicted: torch.Tensor,
        target: torch.Tensor,
        tags: torch.Tensor,
    ) -> None:
        """Update the state (at the end of each batch)."""
        predicted = torch.cat(
            [
                copy(predicted),
                -100 * torch.ones((*predicted.shape[:-1], 1)).to(predicted.device),
            ],
            dim=-1,
        )
        target = torch.cat(
            [
                copy(target),
                -100 * torch.ones((*target.shape[:-1], 1)).to(target.device),
            ],
            dim=-1,
        )
        num_classes = predicted.shape[1]
        predicted = predicted.transpose(1, 2).reshape(-1, num_classes)
        if self.exclusive:
            target = target.flatten()
        else:
            target = target.transpose(1, 2).reshape(-1, num_classes)
        probability = copy(predicted)
        if not self.exclusive:
            predicted = (predicted > self.threshold).int()
        else:
            predicted = torch.max(predicted, 1)[1]
        for c in self.classes:
            if self.exclusive:
                predicted_intervals = self._get_intervals(
                    predicted == c, probability=probability[:, c]
                )
                target_intervals = self._get_intervals(
                    target == c, probability=probability[:, c]
                )
            else:
                predicted_intervals = self._get_intervals(
                    predicted[:, c] == 1, probability=probability[:, c]
                )
                target_intervals = self._get_intervals(
                    target[:, c] == 1, probability=probability[:, c]
                )
            cos_map, count_map, positive, confidence = self.match(
                predicted_intervals, self.iou_threshold, target_intervals
            )
            cos_map = np.array(cos_map)
            cos_map[cos_map != -1] += len(self.count_map[c])
            self.cos_map[c] += list(cos_map)
            self.count_map[c] += count_map
            self.confidence[c] += confidence
            self.positive[c] += positive
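The `overlap` helper inside `match` is a one-dimensional IoU over `(start, end, confidence)` triples; restated standalone for clarity (note that disjoint intervals yield a negative value, which the `ratio` threshold filters out):

def interval_iou(prop, ground):
    """1-D IoU of two (start, end, confidence) triples, as in `mAP.match`."""
    s_p, e_p, _ = prop
    s_g, e_g, _ = ground
    intersection = min(e_p, e_g) - max(s_p, s_g)
    union = max(e_p, e_g) - min(s_p, s_g)
    return intersection / union

print(interval_iou((10, 40, 0.9), (20, 50, 1.0)))  # 20 / 40 = 0.5
print(interval_iou((10, 20, 0.9), (30, 50, 1.0)))  # negative for disjoint intervals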