dlc2action.data.dataset
Behavior dataset (class that manages high-level data interactions).
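A typical workflow builds a BehaviorDataset from an input store and an annotation store and then partitions it for training. The sketch below is a minimal illustration only: the type names and the store-specific keyword arguments (data_path, annotation_path) are hypothetical placeholders, and the parameters your chosen stores actually expect can be listed with BehaviorDataset.get_parameters().

from dlc2action.data.dataset import BehaviorDataset

# Discover what is registered before choosing the types and parameters.
print(BehaviorDataset.data_types())
print(BehaviorDataset.annotation_types())
print(BehaviorDataset.get_parameters("dlc_track", "csv"))  # hypothetical type names

dataset = BehaviorDataset(
    data_type="dlc_track",               # hypothetical data type
    annotation_type="csv",               # hypothetical annotation type
    data_path="/path/to/tracking",       # assumed store-specific parameter
    annotation_path="/path/to/labels",   # assumed store-specific parameter
)
print(dataset.num_classes(), dataset.behaviors_dict())

# Note that the splits are returned in (train, test, val) order.
train, test, val = dataset.partition_train_test_val(
    method="random", val_frac=0.2, test_frac=0.1, normalize=True
)
train.save("train_dataset.pickle")  # pickles the stores' key objects for fast reloading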
1# 2# Copyright 2020-present by A. Mathis Group and contributors. All rights reserved. 3# 4# This project and all its files are licensed under GNU AGPLv3 or later version. 5# A copy is included in dlc2action/LICENSE.AGPL. 6# 7"""Behavior dataset (class that manages high-level data interactions).""" 8 9import inspect 10import os 11import pickle 12import warnings 13from abc import ABC 14from collections import Counter, defaultdict 15from copy import copy, deepcopy 16from typing import Dict, List, Optional, Tuple, Union 17 18import numpy as np 19import torch 20from dlc2action import options 21from dlc2action.data.base_store import BehaviorStore, InputStore 22from dlc2action.utils import ( 23 apply_threshold, 24 apply_threshold_hysteresis, 25 apply_threshold_max, 26) 27from numpy import ndarray 28from torch.utils.data import Dataset 29from tqdm import tqdm 30 31 32class BehaviorDataset(Dataset, ABC): 33 """A generalized dataset class. 34 35 Data and annotation are stored in separate InputStore and BehaviorStore objects; the dataset class 36 manages their interactions. 37 """ 38 39 def __init__( 40 self, 41 data_type: str, 42 annotation_type: str = "none", 43 ssl_transformations: List = None, 44 saved_data_path: str = None, 45 input_store: InputStore = None, 46 annotation_store: BehaviorStore = None, 47 only_load_annotated: bool = False, 48 recompute_annotation: bool = False, 49 # mask: str = None, 50 ids: List = None, 51 **data_parameters, 52 ) -> None: 53 """Initialize a dataset. 54 55 Parameters 56 ---------- 57 data_type : str 58 the data type (see available types by running BehaviorDataset.data_types()) 59 annotation_type : str 60 the annotation type (see available types by running BehaviorDataset.annotation_types()) 61 ssl_transformations : list 62 a list of functions that take a sample dictionary as input and return an (ssl input, ssl target) tuple 63 saved_data_path : str 64 the path to a pre-computed pickled dataset 65 input_store : InputStore 66 a pre-computed input store 67 annotation_store : BehaviorStore 68 a precomputed annotation store 69 only_load_annotated : bool 70 if `True`, the input files that don't have a matching annotation file will be disregarded 71 recompute_annotation : bool 72 if `True`, the annotation will be recomputed even if a precomputed annotation store is provided 73 ids : list 74 a list of ids to load from the input store 75 *data_parameters : dict 76 parameters to initialize the input and annotation stores 77 78 """ 79 mask = None 80 if len(data_parameters) == 0: 81 recompute_annotation = False 82 feature_extraction = data_parameters.get("feature_extraction") 83 if feature_extraction is not None and not issubclass( 84 options.input_stores[data_type], 85 options.feature_extractors[feature_extraction].input_store_class, 86 ): 87 raise ValueError( 88 f"The {feature_extraction} feature extractor does not work with " 89 f"the {data_type} data type, please choose a suclass of " 90 f"{options.feature_extractors[feature_extraction].input_store_class}" 91 ) 92 if ssl_transformations is None: 93 ssl_transformations = [] 94 self.ssl_transformations = ssl_transformations 95 self.input_type = data_type 96 self.annotation_type = annotation_type 97 self.stats = None 98 if mask is not None: 99 with open(mask, "rb") as f: 100 self.mask = pickle.load(f) 101 else: 102 self.mask = None 103 self.ids = ids 104 self.tag = None 105 self.return_unlabeled = None 106 # load saved key objects for annotation and input if they exist 107 input_key_objects, annotation_key_objects = None, 
None 108 if saved_data_path is not None: 109 if os.path.exists(saved_data_path): 110 with open(saved_data_path, "rb") as f: 111 input_key_objects, annotation_key_objects = pickle.load(f) 112 # if the input or the annotation store need to be created, generate the common video order 113 if len(data_parameters) > 0: 114 input_files = options.input_stores[data_type].get_file_ids( 115 **data_parameters 116 ) 117 annotation_files = options.annotation_stores[annotation_type].get_file_ids( 118 **data_parameters 119 ) 120 if only_load_annotated: 121 data_parameters["video_order"] = [ 122 x for x in input_files if x in annotation_files 123 ] 124 else: 125 data_parameters["video_order"] = input_files 126 if len(data_parameters["video_order"]) == 0: 127 raise RuntimeError( 128 "The length of file list is 0! Please check your data parameters!" 129 ) 130 data_parameters["mask"] = self.mask 131 # load or create the input store 132 ok = False 133 if input_store is not None: 134 self.input_store = input_store 135 ok = True 136 elif input_key_objects is not None: 137 try: 138 self.input_store = self._load_input_store(data_type, input_key_objects) 139 ok = True 140 except: 141 warnings.warn("Loading input store from key objects failed") 142 if not ok: 143 self.input_store = self._get_input_store( 144 data_type, deepcopy(data_parameters) 145 ) 146 # get the objects needed to create the annotation store (like a clip length dictionary) 147 annotation_objects = self.input_store.get_annotation_objects() 148 data_parameters.update(annotation_objects) 149 # load or create the annotation store 150 ok = False 151 if annotation_store is not None: 152 self.annotation_store = annotation_store 153 ok = True 154 elif ( 155 (annotation_key_objects is not None) 156 and mask is None 157 and not recompute_annotation 158 ): 159 if len(annotation_key_objects) > 0: 160 try: 161 self.annotation_store = self._load_annotation_store( 162 annotation_type, annotation_key_objects 163 ) 164 ok = True 165 except: 166 warnings.warn("Loading annotation store from key objects failed") 167 if not ok: 168 self.annotation_store = self._get_annotation_store( 169 annotation_type, deepcopy(data_parameters) 170 ) 171 to_remove = self.annotation_store.filtered_indices() 172 if len(to_remove) > 0: 173 print( 174 f"Filtering {100 * len(to_remove) / len(self.annotation_store):.2f}% of samples" 175 ) 176 if len(to_remove) == len(self.annotation_store) and len(to_remove) > 0: 177 raise ValueError("All samples were filtered out!") 178 179 if len(self.input_store) == len(self.annotation_store): 180 self.input_store.remove(to_remove) 181 self.annotation_store.remove(to_remove) 182 self.input_indices = list(range(len(self.input_store))) 183 self.annotation_indices = list(range(len(self.input_store))) 184 self.indices = list(range(len(self.input_store))) 185 186 def __getitem__(self, item: int) -> Dict: 187 idx = self._get_idx(item) 188 input = deepcopy(self.input_store[idx]) 189 target = self.annotation_store[idx] 190 tag = self.input_store.get_tag(idx) 191 ssl_inputs, ssl_targets = self._get_SSL_targets(input) 192 batch = {"input": input} 193 for name, x in zip( 194 ["target", "ssl_inputs", "ssl_targets", "tag"], 195 [target, ssl_inputs, ssl_targets, tag], 196 ): 197 if x is not None: 198 batch[name] = x 199 batch["index"] = idx 200 if self.stats is not None: 201 for key in batch["input"].keys(): 202 key_name = key.split("---")[0] 203 if key_name in self.stats: 204 batch["input"][key][:, batch["input"][key].sum(0) != 0] = ( 205 (batch["input"][key] - 
self.stats[key_name]["mean"]) 206 / (self.stats[key_name]["std"] + 1e-7) 207 )[:, batch["input"][key].sum(0) != 0] 208 return batch 209 210 def __len__(self) -> int: 211 return len(self.indices) 212 # if self.annotation_type != "none": 213 # return self.annotation_store.get_len(return_unlabeled=self.return_unlabeled) 214 # else: 215 # return len(self.input_store) 216 217 def get_tags(self) -> List: 218 """Get a list of all meta tags. 219 220 Returns 221 ------- 222 tags: List 223 a list of unique meta tag values 224 225 """ 226 return self.input_store.get_tags() 227 228 def save(self, save_path: str) -> None: 229 """Save the dictionary. 230 231 Parameters 232 ---------- 233 save_path : str 234 the path where the pickled file will be stored 235 236 """ 237 input_obj = self.input_store.key_objects() 238 annotation_obj = self.annotation_store.key_objects() 239 with open(save_path, "wb") as f: 240 pickle.dump((input_obj, annotation_obj), f) 241 242 def to_ram(self) -> None: 243 """Transfer the dataset to RAM.""" 244 self.input_store.to_ram() 245 self.annotation_store.to_ram() 246 247 def generate_full_length_gt(self) -> Dict: 248 """Generate full-length ground truth from the annotations. 249 250 Returns 251 ------- 252 full_length_gt : dict 253 a nested dictionary where first-level keys are video ids, second-level keys are clip ids and 254 values are the ground truth labels 255 256 """ 257 if self.annotation_class() == "exclusive_classification": 258 gt = torch.zeros((len(self), self.len_segment())) 259 else: 260 gt = torch.zeros( 261 (len(self), len(self.behaviors_dict()), self.len_segment()) 262 ) 263 for i in range(len(self)): 264 gt[i] = self.annotation_store[i] 265 return self.generate_full_length_prediction(gt) 266 267 def generate_full_length_prediction(self, predicted: torch.Tensor) -> Dict: 268 """Map predictions for the equal-length pieces to predictions for the original data. 269 270 Probabilities are averaged over predictions on overlapping intervals. 
        Parameters
        ----------
        predicted: torch.Tensor
            a tensor of predicted probabilities of shape `(N, #classes, #frames)`

        Returns
        -------
        full_length_prediction : dict
            a nested dictionary where first-level keys are video ids, second-level keys are clip ids and values are
            averaged probability tensors

        """
        result = defaultdict(lambda: {})
        counter = defaultdict(lambda: {})
        coordinates = self.input_store.get_original_coordinates()
        for coords, prediction in zip(coordinates, predicted):
            l = self.input_store.get_clip_length_from_coords(coords)
            video_name = self.input_store.get_video_id(coords)
            clip_id = self.input_store.get_clip_id(coords)
            start, end = self.input_store.get_clip_start_end(coords)
            if clip_id not in result[video_name].keys():
                result[video_name][clip_id] = torch.zeros(*prediction.shape[:-1], l)
                counter[video_name][clip_id] = torch.zeros(*prediction.shape[:-1], l)
            result[video_name][clip_id][..., start:end] += (
                prediction.squeeze()[..., : end - start].detach().cpu()
            )
            counter[video_name][clip_id][..., start:end] += 1
        for video_name in result:
            for clip_id in result[video_name]:
                result[video_name][clip_id] /= counter[video_name][clip_id]
                result[video_name][clip_id][counter[video_name][clip_id] == 0] = -100
        result = dict(result)
        return result

    def find_valleys(
        self,
        predicted: Union[torch.Tensor, Dict],
        threshold: float = 0.5,
        min_frames: int = 0,
        visibility_min_score: float = 0,
        visibility_min_frac: float = 0,
        main_class: int = 1,
        low: bool = True,
        predicted_error: torch.Tensor = None,
        error_threshold: float = 0.5,
        hysteresis: bool = False,
        threshold_diff: float = None,
        min_frames_error: int = None,
        smooth_interval: int = 1,
        cut_annotated: bool = False,
    ) -> Dict:
        """Find the intervals where the probability of a certain class is below or above a certain threshold.

        Parameters
        ----------
        predicted : torch.Tensor | dict
            either a tensor of predictions for the data prompts or the output of
            `BehaviorDataset.generate_full_length_prediction`
        threshold : float, default 0.5
            the main (hard) threshold
        min_frames : int, default 0
            the minimum length of the intervals
        visibility_min_score : float, default 0
            the minimum visibility score in the intervals
        visibility_min_frac : float, default 0
            the fraction of the interval that has to have a visibility score larger than `visibility_min_score`
        main_class : int, default 1
            the index of the class the function is inspecting
        low : bool, default True
            if `True`, the probability in the intervals has to be below the threshold; if `False`, it has to be above
        predicted_error : torch.Tensor, optional
            a tensor of error predictions for the data prompts
        error_threshold : float, default 0.5
            the maximum allowed probability of error in the intervals
        hysteresis : bool, default False
            if `True`, the function applies a hysteresis threshold with the soft threshold defined by `threshold_diff`
        threshold_diff : float, optional
            the difference between the soft and hard thresholds if hysteresis is used; if `hysteresis` is `True`,
            `low` is `False` and `threshold_diff` is `None`, the soft threshold condition is that `main_class` has
            a higher probability than the other classes
        min_frames_error : int, optional
            if not `None`, intervals are only considered where the error probability stays below `error_threshold`
            for at least `min_frames_error` consecutive frames
        smooth_interval : int, default 1
            the number of frames to smooth the predictions over
        cut_annotated : bool, default False
            if `True`, annotated intervals will be cut out of the predicted intervals

        Returns
        -------
        valleys : dict
            a dictionary where keys are video ids and values are lists of (start, end, individual name) tuples
            that denote the chosen intervals

        """
        result = defaultdict(lambda: [])
        if type(predicted) is not dict:
            predicted = self.generate_full_length_prediction(predicted)
        if predicted_error is not None:
            predicted_error = self.generate_full_length_prediction(predicted_error)
        elif min_frames_error is not None and min_frames_error != 0:
            # warnings.warn(
            #     f"The min_frames_error parameter is set to {min_frames_error} but no error prediction "
            #     f"is given! Setting min_frames_error to 0."
            # )
            min_frames_error = 0
        if low and hysteresis and threshold_diff is None:
            raise ValueError(
                "Cannot set low=True, hysteresis=True and threshold_diff=None! Please set threshold_diff."
377 ) 378 if cut_annotated: 379 masked_intervals_dict = self.get_annotated_intervals() 380 else: 381 masked_intervals_dict = None 382 print("Valleys found:") 383 for v_id in predicted: 384 for clip_id in predicted[v_id].keys(): 385 if predicted_error is not None: 386 error_mask = predicted_error[v_id][clip_id][1, :] < error_threshold 387 if min_frames_error is not None: 388 output, indices, counts = torch.unique_consecutive( 389 error_mask, return_inverse=True, return_counts=True 390 ) 391 wrong_indices = torch.where( 392 output * (counts < min_frames_error) 393 )[0] 394 if len(wrong_indices) > 0: 395 for i in wrong_indices: 396 error_mask[indices == i] = False 397 else: 398 error_mask = None 399 if masked_intervals_dict is not None: 400 masked_intervals = masked_intervals_dict[v_id][clip_id] 401 else: 402 masked_intervals = None 403 if not hysteresis: 404 res_indices_start, res_indices_end = apply_threshold( 405 predicted[v_id][clip_id][main_class, :], 406 threshold, 407 low, 408 error_mask, 409 min_frames, 410 smooth_interval, 411 masked_intervals, 412 ) 413 elif threshold_diff is not None: 414 if low: 415 soft_threshold = threshold + threshold_diff 416 else: 417 soft_threshold = threshold - threshold_diff 418 res_indices_start, res_indices_end = apply_threshold_hysteresis( 419 predicted[v_id][clip_id][main_class, :], 420 soft_threshold, 421 threshold, 422 low, 423 error_mask, 424 min_frames, 425 smooth_interval, 426 masked_intervals, 427 ) 428 else: 429 res_indices_start, res_indices_end = apply_threshold_max( 430 predicted[v_id][clip_id], 431 threshold, 432 main_class, 433 error_mask, 434 min_frames, 435 smooth_interval, 436 masked_intervals, 437 ) 438 start = self.input_store.get_clip_start(v_id, clip_id) 439 result[v_id] += [ 440 [i + start, j + start, clip_id] 441 for i, j in zip(res_indices_start, res_indices_end) 442 if self.input_store.get_visibility( 443 v_id, clip_id, i, j, visibility_min_score 444 ) 445 > visibility_min_frac 446 ] 447 result[v_id] = sorted(result[v_id]) 448 print(f" {v_id}: {len(result[v_id])}") 449 return dict(result) 450 451 def valleys_union(self, valleys_list) -> Dict: 452 """Find the intersection of two valleys dictionaries. 
453 454 Parameters 455 ---------- 456 valleys_list : list 457 a list of valleys dictionaries 458 459 Returns 460 ------- 461 intersection : dict 462 a new valleys dictionary with the intersection of the input intervals 463 464 """ 465 valleys_list = [x for x in valleys_list if x is not None] 466 if len(valleys_list) == 1: 467 return valleys_list[0] 468 elif len(valleys_list) == 0: 469 return {} 470 union = {} 471 keys_list = [set(valleys.keys()) for valleys in valleys_list] 472 keys = set.union(*keys_list) 473 for v_id in keys: 474 res = [] 475 clips_list = [ 476 set([x[-1] for x in valleys[v_id]]) for valleys in valleys_list 477 ] 478 clips = set.union(*clips_list) 479 for clip_id in clips: 480 clip_intervals = [ 481 x 482 for valleys in valleys_list 483 for x in valleys[v_id] 484 if x[-1] == clip_id 485 ] 486 v_len = self.input_store.get_clip_length(v_id, clip_id) 487 arr = torch.zeros(v_len) 488 for start, end, _ in clip_intervals: 489 arr[start:end] += 1 490 output, indices, counts = torch.unique_consecutive( 491 arr > 0, return_inverse=True, return_counts=True 492 ) 493 long_indices = torch.where(output)[0] 494 res += [ 495 ( 496 (indices == i).nonzero(as_tuple=True)[0][0].item(), 497 (indices == i).nonzero(as_tuple=True)[0][-1].item(), 498 clip_id, 499 ) 500 for i in long_indices 501 ] 502 union[v_id] = res 503 return union 504 505 def valleys_intersection(self, valleys_list) -> Dict: 506 """Find the intersection of two valleys dictionaries. 507 508 Parameters 509 ---------- 510 valleys_list : list 511 a list of valleys dictionaries 512 513 Returns 514 ------- 515 intersection : dict 516 a new valleys dictionary with the intersection of the input intervals 517 518 """ 519 valleys_list = [x for x in valleys_list if x is not None] 520 if len(valleys_list) == 1: 521 return valleys_list[0] 522 elif len(valleys_list) == 0: 523 return {} 524 intersection = {} 525 keys_list = [set(valleys.keys()) for valleys in valleys_list] 526 keys = set.intersection(*keys_list) 527 for v_id in keys: 528 res = [] 529 clips_list = [ 530 set([x[-1] for x in valleys[v_id]]) for valleys in valleys_list 531 ] 532 clips = set.intersection(*clips_list) 533 for clip_id in clips: 534 clip_intervals = [ 535 x 536 for valleys in valleys_list 537 for x in valleys[v_id] 538 if x[-1] == clip_id 539 ] 540 v_len = self.input_store.get_clip_length(v_id, clip_id) 541 arr = torch.zeros(v_len) 542 for start, end, _ in clip_intervals: 543 arr[start:end] += 1 544 output, indices, counts = torch.unique_consecutive( 545 arr, return_inverse=True, return_counts=True 546 ) 547 long_indices = torch.where(output == 2)[0] 548 res += [ 549 ( 550 (indices == i).nonzero(as_tuple=True)[0][0].item(), 551 (indices == i).nonzero(as_tuple=True)[0][-1].item(), 552 clip_id, 553 ) 554 for i in long_indices 555 ] 556 intersection[v_id] = res 557 return intersection 558 559 def partition_train_test_val( 560 self, 561 use_test: float = 0, 562 split_path: str = None, 563 method: str = "random", 564 val_frac: float = 0, 565 test_frac: float = 0, 566 save_split: bool = False, 567 normalize: bool = False, 568 skip_normalization_keys: List = None, 569 stats: Dict = None, 570 ) -> Tuple: 571 """Partition the dataset into three new datasets. 
        Parameters
        ----------
        use_test : float, default 0
            The fraction of the test dataset to be used in training without labels
        split_path : str, optional
            The path to load the split information from (if the `'file'` method is used) and to save it to
            (if `save_split` is `True`)
        method : {'random', 'random:test-from-name', 'random:test-from-name:{name}',
            'val-from-name:{val_name}:test-from-name:{test_name}',
            'random:equalize:segments', 'random:equalize:videos',
            'folders', 'time', 'time:strict', 'file'}
            The partitioning method:
            - `'random'`: sort videos into subsets randomly,
            - `'random:test-from-name'` (or `'random:test-from-name:{name}'`): sort videos into training and
              validation subsets randomly and create the test subset from the video ids that start with a specific
              substring (`'test'` by default, or `name` if provided),
            - `'random:equalize:segments'` and `'random:equalize:videos'`: sort videos into subsets randomly while
              making sure that for the rarest classes at least `0.8 * val_frac` of the videos/segments that contain
              occurrences of the class get into the validation subset and `0.8 * test_frac` get into the test subset;
              this is ensured for all classes in order of increasing number of occurrences until the validation and
              test subsets are full,
            - `'val-from-name:{val_name}:test-from-name:{test_name}'`: create the validation and test
              subsets from the video ids that start with specific substrings (`val_name` for validation
              and `test_name` for test) and sort all other videos into the training subset,
            - `'folders'`: read videos from folders named *test*, *train* and *val* into the corresponding subsets,
            - `'time'`: split each video into training, validation and test subsequences,
            - `'time:strict'`: split each video into validation, test and training subsequences
              and throw out the last segments in validation and test (to get rid of overlaps),
            - `'file'`: split according to a split file.
603 val_frac : float, default 0 604 The fraction of the dataset to be used in validation 605 test_frac : float, default 0 606 The fraction of the dataset to be used in test 607 save_split : bool, default False 608 Save a split file if True 609 normalize : bool, default False 610 Normalize the dataset if `True` 611 skip_normalization_keys : list, optional 612 A list of keys to skip normalization for 613 stats : dict, optional 614 A dictionary of (pre-computed) statistics to use for normalization 615 616 Returns 617 ------- 618 train_dataset : BehaviorDataset 619 train dataset 620 val_dataset : BehaviorDataset 621 validation dataset 622 test_dataset : BehaviorDataset 623 test dataset 624 625 """ 626 train_indices, test_indices, val_indices = self._partition_indices( 627 split_path=split_path, 628 method=method, 629 val_frac=val_frac, 630 test_frac=test_frac, 631 save_split=save_split, 632 ) 633 ssl_indices = None 634 partition_method = method.split(":") 635 if ( 636 partition_method[0] in ("leave-one-in", "leave-n-in") 637 and len(partition_method) > 1 638 and partition_method[2] == "val-for-ssl" 639 ): 640 print("Using validation samples for SSL!") 641 ssl_indices = val_indices 642 643 val_dataset = self._create_new_dataset(val_indices) 644 test_dataset = self._create_new_dataset(test_indices) 645 train_dataset = self._create_new_dataset(train_indices, ssl_indices=ssl_indices) 646 647 train_classes = train_dataset.count_classes() 648 val_classes = val_dataset.count_classes() 649 test_classes = test_dataset.count_classes() 650 print("Number of samples:") 651 print(f" validation:") 652 print(f" {[f'{k}: {val_classes[k]}' for k in sorted(val_classes.keys())]}") 653 print(f" training:") 654 print(f" {[f'{k}: {train_classes[k]}' for k in sorted(train_classes.keys())]}") 655 print(f" test:") 656 print(f" {[f'{k}: {test_classes[k]}' for k in sorted(test_classes.keys())]}") 657 if normalize: 658 if stats is None: 659 print("Computing normalization statistics...") 660 stats = train_dataset.get_normalization_stats(skip_normalization_keys) 661 else: 662 print("Setting loaded normalization statistics...") 663 train_dataset.set_normalization_stats(stats) 664 val_dataset.set_normalization_stats(stats) 665 test_dataset.set_normalization_stats(stats) 666 return train_dataset, test_dataset, val_dataset 667 668 def class_weights(self, proportional=False) -> List: 669 """Calculate class weights in inverse proportion to number of samples. 
670 671 Parameters 672 ---------- 673 proportional : bool, default False 674 If `True`, the weights are proportional to the number of samples in the most common class 675 676 Returns 677 ------- 678 weights: list 679 a list of class weights 680 681 """ 682 items = sorted( 683 [ 684 (k, v) 685 for k, v in self.annotation_store.count_classes().items() 686 if k != -100 687 ] 688 ) 689 if self.annotation_store.annotation_class() == "exclusive_classification": 690 if not proportional: 691 numerator = len(self.annotation_store) 692 else: 693 numerator = max([x[1] for x in items]) 694 weights = [numerator / (v + 1e-7) for _, v in items] 695 else: 696 items_zero = sorted( 697 [ 698 (k, v) 699 for k, v in self.annotation_store.count_classes(zeros=True).items() 700 if k != -100 701 ] 702 ) 703 if not proportional: 704 numerators = defaultdict(lambda: len(self.annotation_store)) 705 else: 706 numerators = { 707 item_one[0]: max(item_one[1], item_zero[1]) 708 for item_one, item_zero in zip(items, items_zero) 709 } 710 weights = {} 711 weights[0] = [numerators[k] / (v + 1e-7) for k, v in items_zero] 712 weights[1] = [numerators[k] / (v + 1e-7) for k, v in items] 713 return weights 714 715 def _boundary_class_weight(self): 716 """Calculate the weight of the boundary class. 717 718 Returns 719 ------- 720 weight: float 721 the weight of the boundary class 722 723 """ 724 if self.annotation_type != "none": 725 f = self.annotation_store.data.flatten() 726 _, inv = torch.unique_consecutive(f, return_inverse=True) 727 boundary = torch.cat([torch.tensor([0]), torch.diff(inv)]).reshape( 728 self.annotation_store.data.shape 729 ) 730 boundary[..., 0] = 0 731 cnt = Counter(boundary.flatten().numpy()) 732 return cnt[1] / cnt[0] 733 else: 734 return 0 735 736 def count_classes(self, bouts: bool = False) -> Dict: 737 """Get a class counter dictionary. 738 739 Parameters 740 ---------- 741 bouts : bool, default False 742 if `True`, instead of frame counts segment counts are returned 743 744 Returns 745 ------- 746 count_dictionary : dict 747 a dictionary with class indices as keys and frame or bout counts as values 748 749 """ 750 return self.annotation_store.count_classes(bouts=bouts) 751 752 def behaviors_dict(self) -> Dict: 753 """Get a behavior dictionary. 754 755 Returns 756 ------- 757 dict 758 behavior dictionary 759 760 """ 761 return self.annotation_store.behaviors_dict() 762 763 def bodyparts_order(self) -> List: 764 """Get the order of bodyparts. 765 766 Returns 767 ------- 768 bodyparts : List 769 a list of bodyparts 770 771 """ 772 try: 773 return self.input_store.get_bodyparts() 774 except: 775 raise RuntimeError( 776 f"The {self.input_type} input store does not have bodyparts implemented!" 777 ) 778 779 def features_shape(self) -> Dict: 780 """Get the shapes of the input features. 781 782 Returns 783 ------- 784 shapes : Dict 785 a dictionary with the shapes of the features 786 787 """ 788 sample = self.input_store[0] 789 shapes = {k: v.shape for k, v in sample.items()} 790 # for key, value in shapes.items(): 791 # print(f'{key}: {value}') 792 return shapes 793 794 def num_classes(self) -> int: 795 """Get the number of classes in the data. 796 797 Returns 798 ------- 799 num_classes : int 800 the number of classes 801 802 """ 803 return len(self.annotation_store.behaviors_dict()) 804 805 def len_segment(self) -> int: 806 """Get the segment length in the data. 
807 808 Returns 809 ------- 810 len_segment : int 811 the segment length 812 813 """ 814 sample = self.input_store[0] 815 key = list(sample.keys())[0] 816 return sample[key].shape[-1] 817 818 def set_ssl_transformations(self, ssl_transformations: List) -> None: 819 """Set new SSL transformations. 820 821 Parameters 822 ---------- 823 ssl_transformations : list 824 a list of functions that take a sample feature dictionary as input and output ssl_inputs and ssl_targets 825 lists 826 827 """ 828 self.ssl_transformations = ssl_transformations 829 830 @classmethod 831 def new(cls, *args, **kwargs): 832 """Create a new object of the same class. 833 834 Parameters 835 ---------- 836 args : list 837 arguments for the constructor 838 kwargs : dict 839 keyword arguments for the constructor 840 841 Returns 842 ------- 843 new_instance: BehaviorDataset 844 a new instance of the same class 845 846 """ 847 return cls(*args, **kwargs) 848 849 @classmethod 850 def get_parameters(cls, data_type: str, annotation_type: str) -> List: 851 """Get parameters necessary for initialization. 852 853 Parameters 854 ---------- 855 data_type : str 856 the data type 857 annotation_type : str 858 the annotation type 859 860 Returns 861 ------- 862 parameters : list 863 a list of parameters 864 865 """ 866 input_features = options.input_stores[data_type].get_parameters() 867 annotation_features = options.annotation_stores[ 868 annotation_type 869 ].get_parameters() 870 self_features = inspect.getfullargspec(cls.__init__).args 871 return self_features + input_features + annotation_features 872 873 @staticmethod 874 def data_types() -> List: 875 """List available data types. 876 877 Returns 878 ------- 879 data_types : list 880 available data types 881 882 """ 883 return list(options.input_stores.keys()) 884 885 @staticmethod 886 def annotation_types() -> List: 887 """List available annotation types. 
888 889 Returns 890 ------- 891 annotation_types : list 892 available annotation types 893 894 """ 895 return list(options.annotation_stores.keys()) 896 897 def _get_SSL_targets(self, input: Dict) -> Tuple[List, List]: 898 """Get the SSL inputs and targets from a sample dictionary.""" 899 ssl_inputs = [] 900 ssl_targets = [] 901 for transform in self.ssl_transformations: 902 ssl_input, ssl_target = transform(copy(input)) 903 ssl_inputs.append(ssl_input) 904 ssl_targets.append(ssl_target) 905 return ssl_inputs, ssl_targets 906 907 def _create_new_dataset(self, indices: List, ssl_indices: List = None): 908 """Create a subsample of the dataset, with samples at ssl_indices losing the annotation.""" 909 if ssl_indices is None: 910 ssl_indices = [] 911 input_store = self.input_store.create_subsample(indices, ssl_indices) 912 annotation_store = self.annotation_store.create_subsample(indices, ssl_indices) 913 new = self.new( 914 data_type=self.input_type, 915 annotation_type=self.annotation_type, 916 ssl_transformations=self.ssl_transformations, 917 annotation_store=annotation_store, 918 input_store=input_store, 919 ids=list(indices) + list(ssl_indices), 920 recompute_annotation=False, 921 ) 922 return new 923 924 def _load_input_store(self, data_type: str, key_objects: Tuple) -> InputStore: 925 """Load input store from key objects.""" 926 input_store = options.input_stores[data_type](key_objects=key_objects) 927 return input_store 928 929 def _load_annotation_store( 930 self, annotation_type: str, key_objects: Tuple 931 ) -> BehaviorStore: 932 """Load annotation store from key objects.""" 933 annotation_store = options.annotation_stores[annotation_type]( 934 key_objects=key_objects 935 ) 936 return annotation_store 937 938 def _get_input_store(self, data_type: str, data_parameters: Dict) -> InputStore: 939 """Create input store from parameters.""" 940 data_parameters["key_objects"] = None 941 input_store = options.input_stores[data_type](**data_parameters) 942 return input_store 943 944 def _get_annotation_store( 945 self, annotation_type: str, data_parameters: Dict 946 ) -> BehaviorStore: 947 """Create annotation store from parameters.""" 948 annotation_store = options.annotation_stores[annotation_type](**data_parameters) 949 return annotation_store 950 951 def set_indexing_parameters(self, unlabeled: bool, tag: int) -> None: 952 """Set the parameters that change the subset that is returned at `__getitem__`. 
953 954 Parameters 955 ---------- 956 unlabeled : bool 957 a pseudolabeling parameter; return only unlabeled samples if `True`, only labeled if `False` and 958 all if `None` 959 tag : int 960 if not `None`, only samples with this meta tag will be returned 961 962 """ 963 if unlabeled != self.return_unlabeled: 964 self.annotation_indices = self.annotation_store.get_indices(unlabeled) 965 self.return_unlabeled = unlabeled 966 if tag != self.tag: 967 self.input_indices = self.input_store.get_indices(tag) 968 self.tag = tag 969 self.indices = [x for x in self.annotation_indices if x in self.input_indices] 970 971 def _get_idx(self, index: int) -> int: 972 """Get index in full dataset.""" 973 return self.indices[index] 974 975 # return self.annotation_store.get_idx( 976 # index, return_unlabeled=self.return_unlabeled 977 # ) 978 979 def _partition_indices( 980 self, 981 split_path: str = None, 982 method: str = "random", 983 val_frac: float = 0, 984 test_frac: float = 0, 985 save_split: bool = False, 986 ) -> Tuple[List, List, List]: 987 """Partition indices into train, validation, test subsets.""" 988 if self.mask is not None: 989 val_indices = self.mask["val_ids"] 990 train_indices = [x for x in range(len(self)) if x not in val_indices] 991 test_indices = [] 992 elif method == "random": 993 videos = np.array(self.input_store.get_video_id_order()) 994 all_videos = list(set(videos)) 995 if len(all_videos) == 1: 996 warnings.warn( 997 "There is only one video in the dataset, so train/val/test split is done on segments; " 998 'that might lead to overlaps, please consider using "time" or "time:strict" as the ' 999 "partitioning method instead" 1000 ) 1001 # Quick fix for single video: the problem with this is that the segments can overlap 1002 # length = int(self.input_store.get_original_coordinates()[-1][1]) # number of segments 1003 length = len(self.input_store.get_original_coordinates()) 1004 val_len = int(val_frac * length) 1005 test_len = int(test_frac * length) 1006 all_indices = np.random.choice(np.arange(length), length, replace=False) 1007 val_indices = all_indices[:val_len] 1008 test_indices = all_indices[val_len : val_len + test_len] 1009 train_indices = all_indices[val_len + test_len :] 1010 coords = self.input_store.get_original_coordinates() 1011 if save_split: 1012 self._save_partition( 1013 coords[train_indices], 1014 coords[val_indices], 1015 coords[test_indices], 1016 split_path, 1017 coords=True, 1018 ) 1019 else: 1020 length = len(all_videos) 1021 val_len = int(val_frac * length) 1022 test_len = int(test_frac * length) 1023 validation = all_videos[:val_len] 1024 test = all_videos[val_len : val_len + test_len] 1025 training = all_videos[val_len + test_len :] 1026 train_indices = np.where(np.isin(videos, training))[0] 1027 val_indices = np.where(np.isin(videos, validation))[0] 1028 test_indices = np.where(np.isin(videos, test))[0] 1029 if save_split: 1030 self._save_partition(training, validation, test, split_path) 1031 elif method.startswith("random:equalize"): 1032 counter = self.count_classes() 1033 counter = sorted(list([(v, k) for k, v in counter.items()])) 1034 classes = [x[1] for x in counter] 1035 indicator = {c: [] for c in classes} 1036 if method.endswith("videos"): 1037 videos = np.array(self.input_store.get_video_id_order()) 1038 all_videos = list(set(videos)) 1039 total_len = len(all_videos) 1040 for video_id in all_videos: 1041 video_coords = np.where(videos == video_id)[0] 1042 ann = torch.cat( 1043 [self.annotation_store[i] for i in video_coords], dim=-1 1044 ) 
1045 for c in classes: 1046 if self.annotation_class() == "nonexclusive_classification": 1047 indicator[c].append(torch.sum(ann[c] == 1) > 0) 1048 elif self.annotation_class() == "exclusive_classification": 1049 indicator[c].append(torch.sum(ann == c) > 0) 1050 else: 1051 raise ValueError( 1052 f"The random:equalize partition method is not implemented" 1053 f"for the {self.annotation_class()} annotation class" 1054 ) 1055 elif method.endswith("segments"): 1056 total_len = len(self) 1057 for ann in self.annotation_store: 1058 for c in classes: 1059 if self.annotation_class() == "nonexclusive_classification": 1060 indicator[c].append(torch.sum(ann[c] == 1) > 0) 1061 elif self.annotation_class() == "exclusive_classification": 1062 indicator[c].append(torch.sum(ann == c) > 0) 1063 else: 1064 raise ValueError( 1065 f"The random:equalize partition method is not implemented" 1066 f"for the {self.annotation_class()} annotation class" 1067 ) 1068 else: 1069 values = [] 1070 for v in options.partition_methods.values(): 1071 values += v 1072 raise ValueError( 1073 f"The {method} partition method is not recognized; please choose from {values}" 1074 ) 1075 val_indices = [] 1076 test_indices = [] 1077 for c in classes: 1078 indicator[c] = np.array(indicator[c]) 1079 ind = np.where(indicator[c])[0] 1080 np.random.shuffle(ind) 1081 c_sum = len(ind) 1082 in_val = np.sum(indicator[c][val_indices]) 1083 in_test = np.sum(indicator[c][test_indices]) 1084 while ( 1085 len(val_indices) < val_frac * total_len 1086 and in_val < val_frac * c_sum * 0.8 1087 ): 1088 first, ind = ind[0], ind[1:] 1089 val_indices = list(set(val_indices).union({first})) 1090 in_val = np.sum(indicator[c][val_indices]) 1091 while ( 1092 len(test_indices) < test_frac * total_len 1093 and in_test < test_frac * c_sum * 0.8 1094 ): 1095 first, ind = ind[0], ind[1:] 1096 test_indices = list(set(test_indices).union({first})) 1097 in_test = np.sum(indicator[c][test_indices]) 1098 if len(val_indices) < int(val_frac * total_len): 1099 left_val = int(val_frac * total_len) - len(val_indices) 1100 else: 1101 left_val = 0 1102 if len(test_indices) < int(test_frac * total_len): 1103 left_test = int(test_frac * total_len) - len(test_indices) 1104 else: 1105 left_test = 0 1106 indicator = np.ones(total_len) 1107 indicator[val_indices] = 0 1108 indicator[test_indices] = 0 1109 ind = np.where(indicator)[0] 1110 np.random.shuffle(ind) 1111 val_indices += list(ind[:left_val]) 1112 test_indices += list(ind[left_val : left_val + left_test]) 1113 train_indices = list(ind[left_val + left_test :]) 1114 if save_split: 1115 if method.endswith("segments"): 1116 coords = self.input_store.get_original_coordinates() 1117 self._save_partition( 1118 coords[train_indices], 1119 coords[val_indices], 1120 coords[test_indices], 1121 coords[split_path], 1122 coords=True, 1123 ) 1124 else: 1125 all_videos = np.array(all_videos) 1126 validation = all_videos[val_indices] 1127 test = all_videos[test_indices] 1128 training = all_videos[train_indices] 1129 self._save_partition(training, validation, test, split_path) 1130 elif method.startswith("random:test-from-name"): 1131 split = method.split(":") 1132 if len(split) > 2: 1133 test_name = split[-1] 1134 else: 1135 test_name = "test" 1136 videos = np.array(self.input_store.get_video_id_order()) 1137 all_videos = list(set(videos)) 1138 test = [] 1139 train_videos = [] 1140 for x in all_videos: 1141 if x.startswith(test_name): 1142 test.append(x) 1143 else: 1144 train_videos.append(x) 1145 length = len(train_videos) 1146 val_len = 
int(val_frac * length) 1147 validation = train_videos[:val_len] 1148 training = train_videos[val_len:] 1149 train_indices = np.where(np.isin(videos, training))[0] 1150 val_indices = np.where(np.isin(videos, validation))[0] 1151 test_indices = np.where(np.isin(videos, test))[0] 1152 if save_split: 1153 self._save_partition(training, validation, test, split_path) 1154 elif method.startswith("val-from-name"): 1155 split = method.split(":") 1156 if split[2] != "test-from-name": 1157 raise ValueError( 1158 f"The {method} partition method is not recognized, please choose from {options.partition_methods}" 1159 ) 1160 val_name = split[1] 1161 test_name = split[-1] 1162 videos = np.array(self.input_store.get_video_id_order()) 1163 all_videos = list(set(videos)) 1164 test = [] 1165 validation = [] 1166 training = [] 1167 for x in all_videos: 1168 if x.startswith(test_name): 1169 test.append(x) 1170 elif x.startswith(val_name): 1171 validation.append(x) 1172 else: 1173 training.append(x) 1174 train_indices = np.where(np.isin(videos, training))[0] 1175 val_indices = np.where(np.isin(videos, validation))[0] 1176 test_indices = np.where(np.isin(videos, test))[0] 1177 elif method == "folders": 1178 folders = np.array(self.input_store.get_folder_order()) 1179 videos = np.array(self.input_store.get_video_id_order()) 1180 train_indices = np.where(np.isin(folders, ["training", "train"]))[0] 1181 if np.sum(np.isin(folders, ["validation", "val"])) > 0: 1182 val_indices = np.where(np.isin(folders, ["validation", "val"]))[0] 1183 else: 1184 train_videos = list(set(videos[train_indices])) 1185 val_len = int(val_frac * len(train_videos)) 1186 validation = train_videos[:val_len] 1187 training = train_videos[val_len:] 1188 train_indices = np.where(np.isin(videos, training))[0] 1189 val_indices = np.where(np.isin(videos, validation))[0] 1190 test_indices = np.where(folders == "test")[0] 1191 if save_split: 1192 self._save_partition( 1193 list(set(videos[train_indices])), 1194 list(set(videos[val_indices])), 1195 list(set(videos[test_indices])), 1196 split_path, 1197 ) 1198 elif method.startswith("leave-one-out"): 1199 n = int(method.split(":")[-1]) 1200 videos = np.array(self.input_store.get_video_id_order()) 1201 all_videos = sorted(list(set(videos))) 1202 print(len(all_videos)) 1203 validation = [all_videos.pop(n)] 1204 training = all_videos 1205 train_indices = np.where(np.isin(videos, training))[0] 1206 val_indices = np.where(np.isin(videos, validation))[0] 1207 test_indices = np.array([]) 1208 elif method.startswith("leave-one-in"): 1209 n = int(method.split(":")[1]) 1210 videos = np.array(self.input_store.get_video_id_order()) 1211 all_videos = sorted(list(set(videos))) 1212 training = [all_videos.pop(n)] 1213 validation = all_videos 1214 train_indices = np.where(np.isin(videos, training))[0] 1215 val_indices = np.where(np.isin(videos, validation))[0] 1216 test_indices = np.array([]) 1217 elif method.startswith("leave-n-in"): 1218 train_idx = [int(i) for i in method.split(":")[1].split(",")] 1219 videos = np.array(self.input_store.get_video_id_order()) 1220 all_videos = sorted(list(set(videos))) 1221 training = [v for i, v in enumerate(all_videos) if i in train_idx] 1222 validation = [v for i, v in enumerate(all_videos) if i not in train_idx] 1223 train_indices = np.where(np.isin(videos, training))[0] 1224 val_indices = np.where(np.isin(videos, validation))[0] 1225 test_indices = np.array([]) 1226 elif method.startswith("time"): 1227 if method.endswith("strict"): 1228 len_segment = self.len_segment() 1229 step 
= self.input_store.step 1230 num_removed = len_segment // step 1231 else: 1232 num_removed = 0 1233 videos = np.array(self.input_store.get_video_id_order()) 1234 all_videos = set(videos) 1235 train_indices = [] 1236 val_indices = [] 1237 test_indices = [] 1238 start = 0 1239 if len(method.split(":")) > 1 and method.split(":")[1] == "start-from": 1240 start = float(method.split(":")[2]) 1241 for video_id in all_videos: 1242 video_indices = np.where(videos == video_id)[0] 1243 val_len = int(val_frac * len(video_indices)) 1244 test_len = int(test_frac * len(video_indices)) 1245 start_pos = int(start * len(video_indices)) 1246 all_ind = np.ones(len(video_indices)) 1247 val_indices += list(video_indices[start_pos : start_pos + val_len]) 1248 all_ind[start_pos : start_pos + val_len] = 0 1249 if start_pos + val_len > len(video_indices): 1250 p = start_pos + val_len - len(video_indices) 1251 val_indices += list(video_indices[:p]) 1252 all_ind[:p] = 0 1253 else: 1254 p = start_pos + val_len 1255 test_indices += list(video_indices[p : p + test_len]) 1256 all_ind[p : p + test_len] = 0 1257 if p + test_len > len(video_indices): 1258 p = test_len + p - len(video_indices) 1259 test_indices += list(video_indices[:p]) 1260 all_ind[:p] = 0 1261 train_indices += list(video_indices[all_ind > 0]) 1262 for _ in range(num_removed): 1263 if len(val_indices) > 0: 1264 val_indices.pop(-1) 1265 if len(test_indices) > 0: 1266 test_indices.pop(-1) 1267 if start > 0 and len(train_indices) > 0: 1268 train_indices.pop(-1) 1269 elif method == "file": 1270 if split_path is None: 1271 raise ValueError( 1272 'You need to either set split_path or change partition method ("file" requires a file)' 1273 ) 1274 active_list = None 1275 training, validation, test = [], [], [] 1276 with open(split_path) as f: 1277 for line in f.readlines(): 1278 if line.startswith("Train"): 1279 active_list = training 1280 elif line.startswith("Valid"): 1281 active_list = validation 1282 elif line.startswith("Test"): 1283 active_list = test 1284 else: 1285 stripped_line = line.rstrip(",\n ") 1286 if stripped_line == "": 1287 continue 1288 if ", " in stripped_line: 1289 active_list += stripped_line.split(", ") 1290 else: 1291 active_list.append(stripped_line) 1292 all_lines = training + validation + test 1293 if len(all_lines[0].split("---")) == 3: 1294 entry_type = "coords" 1295 else: 1296 entry_type = "videos" 1297 1298 if entry_type == "videos": 1299 videos = np.array(self.input_store.get_video_id_order()) 1300 val_indices = np.where(np.isin(videos, validation))[0] 1301 test_indices = np.where(np.isin(videos, test))[0] 1302 train_indices = np.where(np.isin(videos, training))[0] 1303 elif entry_type == "coords": 1304 coords = self.input_store.get_original_coordinates() 1305 video_ids = self.input_store.get_video_id_order() 1306 clip_ids = [self.input_store.get_clip_id(coord) for coord in coords] 1307 starts, ends = zip( 1308 *[self.input_store.get_clip_start_end(coord) for coord in coords] 1309 ) 1310 coords = np.array( 1311 [ 1312 f"{video_id}---{clip_id}---{start}-{end}" 1313 for video_id, clip_id, start, end in zip( 1314 video_ids, clip_ids, starts, ends 1315 ) 1316 ] 1317 ) 1318 val_indices = np.where(np.isin(coords, validation))[0] 1319 test_indices = np.where(np.isin(coords, test))[0] 1320 train_indices = np.where(np.isin(coords, training))[0] 1321 else: 1322 raise ValueError("The split path has unrecognized format!") 1323 all_indices = np.ones(len(self)) 1324 if len(train_indices) == 0: 1325 all_indices[val_indices] = 0 1326 
all_indices[test_indices] = 0 1327 train_indices = np.where(all_indices)[0] 1328 elif len(val_indices) == 0: 1329 all_indices[train_indices] = 0 1330 all_indices[test_indices] = 0 1331 val_indices = np.where(all_indices)[0] 1332 elif len(test_indices) == 0: 1333 all_indices[train_indices] = 0 1334 all_indices[val_indices] = 0 1335 test_indices = np.where(all_indices)[0] 1336 else: 1337 raise ValueError( 1338 f"The {method} partition is not recognized, please choose from {options.partition_methods}" 1339 ) 1340 return sorted(train_indices), sorted(test_indices), sorted(val_indices) 1341 1342 def _save_partition( 1343 self, 1344 training: List, 1345 validation: List, 1346 test: List, 1347 split_path: str, 1348 coords: bool = False, 1349 ) -> None: 1350 """Save a split file.""" 1351 if coords: 1352 name = "coords" 1353 training_coords = [] 1354 val_coords = [] 1355 test_coords = [] 1356 for coord in training: 1357 video_id = self.input_store.get_video_id(coord) 1358 clip_id = self.input_store.get_clip_id(coord) 1359 start, end = self.input_store.get_clip_start_end(coord) 1360 training_coords.append(f"{video_id}---{clip_id}---{start}-{end}") 1361 for coord in validation: 1362 video_id = self.input_store.get_video_id(coord) 1363 clip_id = self.input_store.get_clip_id(coord) 1364 start, end = self.input_store.get_clip_start_end(coord) 1365 val_coords.append(f"{video_id}---{clip_id}---{start}-{end}") 1366 for coord in test: 1367 video_id = self.input_store.get_video_id(coord) 1368 clip_id = self.input_store.get_clip_id(coord) 1369 start, end = self.input_store.get_clip_start_end(coord) 1370 test_coords.append(f"{video_id}---{clip_id}---{start}-{end}") 1371 training, validation, test = training_coords, val_coords, test_coords 1372 else: 1373 name = "videos" 1374 if split_path is not None: 1375 with open(split_path, "w") as f: 1376 f.write(f"Training {name}:\n") 1377 for x in training: 1378 f.write(x + "\n") 1379 f.write(f"Validation {name}:\n") 1380 for x in validation: 1381 f.write(x + "\n") 1382 f.write(f"Test {name}:\n") 1383 for x in test: 1384 f.write(x + "\n") 1385 1386 def _get_intervals_from_ind(self, frame_indices: np.ndarray): 1387 """Get a list of intervals from a list of frame indices. 1388 1389 Example: `[0, 1, 2, 5, 6, 8] -> [[0, 3], [5, 7], [8, 9]]`. 1390 1391 Parameters 1392 ---------- 1393 frame_indices : np.ndarray 1394 a list of frame indices 1395 1396 Returns 1397 ------- 1398 intervals : list 1399 a list of interval boundaries 1400 1401 """ 1402 masked_intervals = [] 1403 breaks = np.where(np.diff(frame_indices) != 1)[0] 1404 if len(frame_indices) > 0: 1405 start = frame_indices[0] 1406 for k in breaks: 1407 masked_intervals.append([start, frame_indices[k] + 1]) 1408 start = frame_indices[k + 1] 1409 masked_intervals.append([start, frame_indices[-1] + 1]) 1410 return masked_intervals 1411 1412 def get_intervals(self) -> Tuple[dict, Optional[list]]: 1413 """Get a list of intervals covered by the dataset in the original coordinates. 
1414 1415 Returns 1416 ------- 1417 intervals : dict 1418 a nested dictionary where first-level keys are video ids, second-level keys are clip ids and 1419 values are lists of the intervals in `[start, end]` format 1420 1421 """ 1422 counter = defaultdict(lambda: {}) 1423 coordinates = self.input_store.get_original_coordinates() 1424 for coords in coordinates: 1425 l = self.input_store.get_clip_length_from_coords(coords) 1426 video_name = self.input_store.get_video_id(coords) 1427 clip_id = self.input_store.get_clip_id(coords) 1428 start, end = self.input_store.get_clip_start_end(coords) 1429 if clip_id not in counter[video_name]: 1430 counter[video_name][clip_id] = np.zeros(l) 1431 counter[video_name][clip_id][start:end] = 1 1432 result = {video_name: {} for video_name in counter} 1433 for video_name in counter: 1434 for clip_id in counter[video_name]: 1435 result[video_name][clip_id] = self._get_intervals_from_ind( 1436 np.where(counter[video_name][clip_id])[0] 1437 ) 1438 return result, self.ids 1439 1440 def get_unannotated_intervals(self, first_intervals=None) -> Dict: 1441 """Get a list of intervals in the original coordinates where there is no annotation. 1442 1443 Parameters 1444 ---------- 1445 first_intervals : dict 1446 a nested dictionary where first-level keys are video ids, second-level keys are clip ids and 1447 values are lists of the intervals in `[start, end]` format. If provided, only the intersection with 1448 those intervals will be returned 1449 1450 Returns 1451 ------- 1452 intervals : dict 1453 a nested dictionary where first-level keys are video ids, second-level keys are clip ids and 1454 values are lists of the intervals in `[start, end]` format 1455 1456 """ 1457 counter_value = 2 1458 if first_intervals is None: 1459 first_intervals = defaultdict(lambda: defaultdict(lambda: [])) 1460 counter_value = 1 1461 counter = defaultdict(lambda: {}) 1462 coordinates = self.input_store.get_original_coordinates() 1463 for i, coords in enumerate(coordinates): 1464 l = self.input_store.get_clip_length_from_coords(coords) 1465 ann = self.annotation_store[i] 1466 if ( 1467 self.annotation_store.annotation_class() 1468 == "nonexclusive_classification" 1469 ): 1470 ann = ann[0, :] 1471 video_name = self.input_store.get_video_id(coords) 1472 clip_id = self.input_store.get_clip_id(coords) 1473 start, end = self.input_store.get_clip_start_end(coords) 1474 if clip_id not in counter[video_name]: 1475 counter[video_name][clip_id] = np.ones(l) 1476 counter[video_name][clip_id][start:end] = (ann[: end - start] == -100).int() 1477 result = {video_name: {} for video_name in counter} 1478 for video_name in counter: 1479 for clip_id in counter[video_name]: 1480 for start, end in first_intervals[video_name][clip_id]: 1481 counter[video_name][clip_id][start:end] += 1 1482 result[video_name][clip_id] = self._get_intervals_from_ind( 1483 np.where(counter[video_name][clip_id] == counter_value)[0] 1484 ) 1485 return result 1486 1487 def get_annotated_intervals(self) -> Dict: 1488 """Get a list of intervals in the original coordinates where there is no annotation. 

        Returns
        -------
        intervals : dict
            a nested dictionary where first-level keys are video ids, second-level keys are clip ids and
            values are lists of the intervals in `[start, end]` format

        """
        if self.annotation_type == "none":
            return []
        counter_value = 1
        counter = defaultdict(lambda: {})
        coordinates = self.input_store.get_original_coordinates()
        for i, coords in enumerate(coordinates):
            l = self.input_store.get_clip_length_from_coords(coords)
            ann = self.annotation_store[i]
            video_name = self.input_store.get_video_id(coords)
            clip_id = self.input_store.get_clip_id(coords)
            start, end = self.input_store.get_clip_start_end(coords)
            if clip_id not in counter[video_name]:
                counter[video_name][clip_id] = np.zeros(l)
            if (
                self.annotation_store.annotation_class()
                == "nonexclusive_classification"
            ):
                counter[video_name][clip_id][start:end] = (
                    torch.sum(ann[:, : end - start] != -100, dim=0) > 0
                ).int()
            else:
                counter[video_name][clip_id][start:end] = (
                    ann[: end - start] != -100
                ).int()
        result = {video_name: {} for video_name in counter}
        for video_name in counter:
            for clip_id in counter[video_name]:
                result[video_name][clip_id] = self._get_intervals_from_ind(
                    np.where(counter[video_name][clip_id] == counter_value)[0]
                )
        return result

    def get_ids(self) -> Dict:
        """Get a dictionary of all clip ids in the dataset.

        Returns
        -------
        ids : dict
            a dictionary where keys are video ids and values are lists of clip ids

        """
        coordinates = self.input_store.get_original_coordinates()
        video_ids = np.array(self.input_store.get_video_id_order())
        id_set = set(video_ids)
        result = {}
        for video_id in id_set:
            coords = coordinates[video_ids == video_id]
            clip_ids = list({self.input_store.get_clip_id(c) for c in coords})
            result[video_id] = clip_ids
        return result

    def get_len(self, video_id: str, clip_id: str) -> int:
        """Get the length of a specific clip.

        Parameters
        ----------
        video_id : str
            the video id
        clip_id : str
            the clip id

        Returns
        -------
        length : int
            the length

        """
        return self.input_store.get_clip_length(video_id, clip_id)

    def get_confusion_matrix(
        self, prediction: torch.Tensor, confusion_type: str = "recall"
    ) -> Tuple[ndarray, list, Optional[str]]:
        """Get a confusion matrix.

        Parameters
        ----------
        prediction : torch.Tensor
            a tensor of predicted class probabilities of shape `(#samples, #classes, #frames)`
        confusion_type : {"recall", "precision"}
            for datasets with non-exclusive annotation, if `confusion_type` is `"recall"`, only false positives
            are taken into account, and if `confusion_type` is `"precision"`, only false negatives

        Returns
        -------
        confusion_matrix : np.ndarray
            a confusion matrix of shape `(#classes, #classes)` where `A[i, j] = F_ij/N_i`, `F_ij` is the number of
            frames that have the i-th label in the ground truth and a false positive j-th label in the prediction,
            `N_i` is the number of frames that have the i-th label in the ground truth
        classes : list
            a list of classes
        confusion_type : str | None
            the confusion type that was used (`None` for exclusive classification)

        """
        behaviors_dict = self.annotation_store.behaviors_dict()
        num_behaviors = len(behaviors_dict)
        confusion_matrix = np.zeros((num_behaviors, num_behaviors))
        if self.annotation_store.annotation_class() == "exclusive_classification":
            exclusive = True
            confusion_type = None
        elif self.annotation_store.annotation_class() == "nonexclusive_classification":
            exclusive = False
        else:
            raise RuntimeError(
                f"The {self.annotation_store.annotation_class()} annotation class is not recognized!"
            )
        for ann, p in zip(self.annotation_store, prediction):
            if exclusive:
                class_prediction = torch.max(p, dim=0)[1]
                for i in behaviors_dict.keys():
                    for j in behaviors_dict.keys():
                        confusion_matrix[i, j] += int(
                            torch.sum(class_prediction[ann == i] == j)
                        )
            else:
                class_prediction = (p > 0.5).int()
                for i in behaviors_dict.keys():
                    for j in behaviors_dict.keys():
                        if confusion_type == "recall":
                            pred = deepcopy(class_prediction[j])
                            if i != j:
                                pred[ann[j] == 1] = 0
                            confusion_matrix[i, j] += int(torch.sum(pred[ann[i] == 1]))
                        elif confusion_type == "precision":
                            annotation = deepcopy(ann[j])
                            if i != j:
                                annotation[class_prediction[j] == 1] = 0
                            confusion_matrix[i, j] += int(
                                torch.sum(annotation[class_prediction[i] == 1])
                            )
                        else:
                            raise ValueError(
                                f"The {confusion_type} type is not recognized; please choose from ['recall', 'precision']"
                            )
        counter = self.annotation_store.count_classes()
        for i in behaviors_dict.keys():
            if counter[i] != 0:
                if confusion_type == "recall" or confusion_type is None:
                    confusion_matrix[i, :] /= counter[i]
                else:
                    confusion_matrix[:, i] /= counter[i]
        return confusion_matrix, list(behaviors_dict.values()), confusion_type

    def annotation_class(self) -> str:
        """Get the type of annotation ('exclusive_classification', 'nonexclusive_classification', more coming soon).

        Returns
        -------
        annotation_class : str
            the type of annotation

        """
        return self.annotation_store.annotation_class()

    def set_normalization_stats(self, stats: Dict) -> None:
        """Set the stats to normalize data at runtime.

        Parameters
        ----------
        stats : dict
            a nested dictionary where first-level keys are feature key names, second-level keys are 'mean' and 'std'
            and values are the statistics in `torch` tensors of shape `(#features, 1)`

        """
        self.stats = stats

    def get_min_max_frames(self, video_id) -> Tuple[Dict, Dict]:
        """Get the minimum and maximum frame numbers for each clip in a video.

        Parameters
        ----------
        video_id : str
            the video id

        Returns
        -------
        min_frames : dict
            a dictionary where keys are clip ids and values are the minimum frame numbers
        max_frames : dict
            a dictionary where keys are clip ids and values are the maximum frame numbers

        """
        coords = self.input_store.get_original_coordinates()
        clips = set(
            [
                self.input_store.get_clip_id(c)
                for c in coords
                if self.input_store.get_video_id(c) == video_id
            ]
        )
        min_frames = {}
        max_frames = {}
        for clip in clips:
            start = self.input_store.get_clip_start(video_id, clip)
            end = start + self.input_store.get_clip_length(video_id, clip)
            min_frames[clip] = start
            max_frames[clip] = end - 1
        return min_frames, max_frames

    def get_normalization_stats(self, skip_keys=None) -> Dict:
        """Get mean and standard deviation for each key.

        Parameters
        ----------
        skip_keys : list, optional
            a list of keys to skip

        Returns
        -------
        stats : dict
            a nested dictionary where first-level keys are feature key names, second-level keys are 'mean' and 'std'
            and values are the statistics in `torch` tensors of shape `(#features, 1)`

        """
        stats = defaultdict(lambda: {})
        sums = defaultdict(lambda: 0)
        if skip_keys is None:
            skip_keys = []
        counter = defaultdict(lambda: 0)
        for sample in tqdm(self):
            for key, value in sample["input"].items():
                key_name = key.split("---")[0]
                if key_name not in skip_keys:
                    sums[key_name] += value[:, value.sum(0) != 0].sum(-1)
                    counter[key_name] += torch.sum(value.sum(0) != 0)
        for key, value in sums.items():
            stats[key]["mean"] = (value / counter[key]).unsqueeze(-1)
        sums = defaultdict(lambda: 0)
        for sample in tqdm(self):
            for key, value in sample["input"].items():
                key_name = key.split("---")[0]
                if key_name not in skip_keys:
                    sums[key_name] += (
                        (value[:, value.sum(0) != 0] - stats[key_name]["mean"]) ** 2
                    ).sum(-1)
        for key, value in sums.items():
            stats[key]["std"] = np.sqrt(value.unsqueeze(-1) / counter[key])
        return stats
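The normalization and evaluation helpers defined at the end of the class are typically chained: statistics returned by `get_normalization_stats` on a training subset are passed to the other subsets through `set_normalization_stats`, and model output of shape `(#samples, #classes, #frames)` can be summarized with `get_confusion_matrix`. A minimal sketch, assuming `train_ds` and `val_ds` are subsets returned by `partition_train_test_val` and using a random tensor in place of real model output:

import torch

# Assumption: train_ds and val_ds are BehaviorDataset splits from partition_train_test_val.
stats = train_ds.get_normalization_stats()
train_ds.set_normalization_stats(stats)
val_ds.set_normalization_stats(stats)

# Random stand-in for predicted class probabilities, shape (#samples, #classes, #frames).
prediction = torch.rand(len(val_ds), val_ds.num_classes(), val_ds.len_segment())
matrix, classes, used_type = val_ds.get_confusion_matrix(prediction, confusion_type="recall")
print(classes)          # behavior names
print(matrix.round(2))  # rows normalized by ground-truth frame counts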
class BehaviorDataset(Dataset, ABC):
A generalized dataset class.
Data and annotation are stored in separate InputStore and BehaviorStore objects; the dataset class manages their interactions.
40 def __init__( 41 self, 42 data_type: str, 43 annotation_type: str = "none", 44 ssl_transformations: List = None, 45 saved_data_path: str = None, 46 input_store: InputStore = None, 47 annotation_store: BehaviorStore = None, 48 only_load_annotated: bool = False, 49 recompute_annotation: bool = False, 50 # mask: str = None, 51 ids: List = None, 52 **data_parameters, 53 ) -> None: 54 """Initialize a dataset. 55 56 Parameters 57 ---------- 58 data_type : str 59 the data type (see available types by running BehaviorDataset.data_types()) 60 annotation_type : str 61 the annotation type (see available types by running BehaviorDataset.annotation_types()) 62 ssl_transformations : list 63 a list of functions that take a sample dictionary as input and return an (ssl input, ssl target) tuple 64 saved_data_path : str 65 the path to a pre-computed pickled dataset 66 input_store : InputStore 67 a pre-computed input store 68 annotation_store : BehaviorStore 69 a precomputed annotation store 70 only_load_annotated : bool 71 if `True`, the input files that don't have a matching annotation file will be disregarded 72 recompute_annotation : bool 73 if `True`, the annotation will be recomputed even if a precomputed annotation store is provided 74 ids : list 75 a list of ids to load from the input store 76 *data_parameters : dict 77 parameters to initialize the input and annotation stores 78 79 """ 80 mask = None 81 if len(data_parameters) == 0: 82 recompute_annotation = False 83 feature_extraction = data_parameters.get("feature_extraction") 84 if feature_extraction is not None and not issubclass( 85 options.input_stores[data_type], 86 options.feature_extractors[feature_extraction].input_store_class, 87 ): 88 raise ValueError( 89 f"The {feature_extraction} feature extractor does not work with " 90 f"the {data_type} data type, please choose a suclass of " 91 f"{options.feature_extractors[feature_extraction].input_store_class}" 92 ) 93 if ssl_transformations is None: 94 ssl_transformations = [] 95 self.ssl_transformations = ssl_transformations 96 self.input_type = data_type 97 self.annotation_type = annotation_type 98 self.stats = None 99 if mask is not None: 100 with open(mask, "rb") as f: 101 self.mask = pickle.load(f) 102 else: 103 self.mask = None 104 self.ids = ids 105 self.tag = None 106 self.return_unlabeled = None 107 # load saved key objects for annotation and input if they exist 108 input_key_objects, annotation_key_objects = None, None 109 if saved_data_path is not None: 110 if os.path.exists(saved_data_path): 111 with open(saved_data_path, "rb") as f: 112 input_key_objects, annotation_key_objects = pickle.load(f) 113 # if the input or the annotation store need to be created, generate the common video order 114 if len(data_parameters) > 0: 115 input_files = options.input_stores[data_type].get_file_ids( 116 **data_parameters 117 ) 118 annotation_files = options.annotation_stores[annotation_type].get_file_ids( 119 **data_parameters 120 ) 121 if only_load_annotated: 122 data_parameters["video_order"] = [ 123 x for x in input_files if x in annotation_files 124 ] 125 else: 126 data_parameters["video_order"] = input_files 127 if len(data_parameters["video_order"]) == 0: 128 raise RuntimeError( 129 "The length of file list is 0! Please check your data parameters!" 
130 ) 131 data_parameters["mask"] = self.mask 132 # load or create the input store 133 ok = False 134 if input_store is not None: 135 self.input_store = input_store 136 ok = True 137 elif input_key_objects is not None: 138 try: 139 self.input_store = self._load_input_store(data_type, input_key_objects) 140 ok = True 141 except: 142 warnings.warn("Loading input store from key objects failed") 143 if not ok: 144 self.input_store = self._get_input_store( 145 data_type, deepcopy(data_parameters) 146 ) 147 # get the objects needed to create the annotation store (like a clip length dictionary) 148 annotation_objects = self.input_store.get_annotation_objects() 149 data_parameters.update(annotation_objects) 150 # load or create the annotation store 151 ok = False 152 if annotation_store is not None: 153 self.annotation_store = annotation_store 154 ok = True 155 elif ( 156 (annotation_key_objects is not None) 157 and mask is None 158 and not recompute_annotation 159 ): 160 if len(annotation_key_objects) > 0: 161 try: 162 self.annotation_store = self._load_annotation_store( 163 annotation_type, annotation_key_objects 164 ) 165 ok = True 166 except: 167 warnings.warn("Loading annotation store from key objects failed") 168 if not ok: 169 self.annotation_store = self._get_annotation_store( 170 annotation_type, deepcopy(data_parameters) 171 ) 172 to_remove = self.annotation_store.filtered_indices() 173 if len(to_remove) > 0: 174 print( 175 f"Filtering {100 * len(to_remove) / len(self.annotation_store):.2f}% of samples" 176 ) 177 if len(to_remove) == len(self.annotation_store) and len(to_remove) > 0: 178 raise ValueError("All samples were filtered out!") 179 180 if len(self.input_store) == len(self.annotation_store): 181 self.input_store.remove(to_remove) 182 self.annotation_store.remove(to_remove) 183 self.input_indices = list(range(len(self.input_store))) 184 self.annotation_indices = list(range(len(self.input_store))) 185 self.indices = list(range(len(self.input_store)))
Initialize a dataset.
Parameters
data_type : str
the data type (see available types by running BehaviorDataset.data_types())
annotation_type : str
the annotation type (see available types by running BehaviorDataset.annotation_types())
ssl_transformations : list
a list of functions that take a sample dictionary as input and return an (ssl input, ssl target) tuple
saved_data_path : str
the path to a pre-computed pickled dataset
input_store : InputStore
a pre-computed input store
annotation_store : BehaviorStore
a precomputed annotation store
only_load_annotated : bool
if `True`, the input files that don't have a matching annotation file will be disregarded
recompute_annotation : bool
if `True`, the annotation will be recomputed even if a precomputed annotation store is provided
ids : list
a list of ids to load from the input store
**data_parameters : dict
parameters to initialize the input and annotation stores
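A minimal construction sketch is shown below. The type names and store parameters are placeholders (anything not listed above, such as `data_path` and `annotation_path`, is an assumption); the valid values depend on which input and annotation stores are registered in `options` in your installation.

```python
from dlc2action.data.dataset import BehaviorDataset

# check BehaviorDataset.data_types() and BehaviorDataset.annotation_types()
# for the type names available in your installation
dataset = BehaviorDataset(
    data_type="my_data_type",               # placeholder type name
    annotation_type="my_annotation_type",   # placeholder type name
    data_path="/path/to/data",              # assumed store parameter
    annotation_path="/path/to/annotations", # assumed store parameter
    only_load_annotated=True,
)
print(len(dataset))  # number of equal-length segments in the dataset
```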
229 def save(self, save_path: str) -> None: 230 """Save the dictionary. 231 232 Parameters 233 ---------- 234 save_path : str 235 the path where the pickled file will be stored 236 237 """ 238 input_obj = self.input_store.key_objects() 239 annotation_obj = self.annotation_store.key_objects() 240 with open(save_path, "wb") as f: 241 pickle.dump((input_obj, annotation_obj), f)
Save the dataset's key objects to a pickled file.
Parameters
save_path : str the path where the pickled file will be stored
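Continuing the sketch above, the key objects can be pickled once and reused to skip recomputation on a later run (the constructor arguments are assumed to match the original dataset):

```python
# save the pre-computed key objects
dataset.save("dataset_key_objects.pickle")

# later: re-create the dataset from the saved key objects
dataset_reloaded = BehaviorDataset(
    data_type="my_data_type",               # placeholder type name
    annotation_type="my_annotation_type",   # placeholder type name
    saved_data_path="dataset_key_objects.pickle",
)
```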
243 def to_ram(self) -> None: 244 """Transfer the dataset to RAM.""" 245 self.input_store.to_ram() 246 self.annotation_store.to_ram()
Transfer the dataset to RAM.
248 def generate_full_length_gt(self) -> Dict: 249 """Generate full-length ground truth from the annotations. 250 251 Returns 252 ------- 253 full_length_gt : dict 254 a nested dictionary where first-level keys are video ids, second-level keys are clip ids and 255 values are the ground truth labels 256 257 """ 258 if self.annotation_class() == "exclusive_classification": 259 gt = torch.zeros((len(self), self.len_segment())) 260 else: 261 gt = torch.zeros( 262 (len(self), len(self.behaviors_dict()), self.len_segment()) 263 ) 264 for i in range(len(self)): 265 gt[i] = self.annotation_store[i] 266 return self.generate_full_length_prediction(gt)
Generate full-length ground truth from the annotations.
Returns
full_length_gt : dict a nested dictionary where first-level keys are video ids, second-level keys are clip ids and values are the ground truth labels
268 def generate_full_length_prediction(self, predicted: torch.Tensor) -> Dict: 269 """Map predictions for the equal-length pieces to predictions for the original data. 270 271 Probabilities are averaged over predictions on overlapping intervals. 272 273 Parameters 274 ---------- 275 predicted: torch.Tensor 276 a tensor of predicted probabilities of shape `(N, #classes, #frames)` 277 278 Returns 279 ------- 280 full_length_prediction : dict 281 a nested dictionary where first-level keys are video ids, second-level keys are clip ids and values are 282 averaged probability tensors 283 284 """ 285 result = defaultdict(lambda: {}) 286 counter = defaultdict(lambda: {}) 287 coordinates = self.input_store.get_original_coordinates() 288 for coords, prediction in zip(coordinates, predicted): 289 l = self.input_store.get_clip_length_from_coords(coords) 290 video_name = self.input_store.get_video_id(coords) 291 clip_id = self.input_store.get_clip_id(coords) 292 start, end = self.input_store.get_clip_start_end(coords) 293 if clip_id not in result[video_name].keys(): 294 result[video_name][clip_id] = torch.zeros(*prediction.shape[:-1], l) 295 counter[video_name][clip_id] = torch.zeros(*prediction.shape[:-1], l) 296 result[video_name][clip_id][..., start:end] += ( 297 prediction.squeeze()[..., : end - start].detach().cpu() 298 ) 299 counter[video_name][clip_id][..., start:end] += 1 300 for video_name in result: 301 for clip_id in result[video_name]: 302 result[video_name][clip_id] /= counter[video_name][clip_id] 303 result[video_name][clip_id][counter[video_name][clip_id] == 0] = -100 304 result = dict(result) 305 return result
Map predictions for the equal-length pieces to predictions for the original data.
Probabilities are averaged over predictions on overlapping intervals.
Parameters
predicted: torch.Tensor
a tensor of predicted probabilities of shape (N, #classes, #frames)
Returns
full_length_prediction : dict a nested dictionary where first-level keys are video ids, second-level keys are clip ids and values are averaged probability tensors
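A sketch of the expected shapes, assuming per-segment probabilities collected from a model in the same order as the dataset (random values stand in for real predictions):

```python
import torch

# stacked per-segment probabilities: shape (N, #classes, #frames), N == len(dataset)
predicted = torch.rand(len(dataset), dataset.num_classes(), dataset.len_segment())

full_length = dataset.generate_full_length_prediction(predicted)
for video_id, clips in full_length.items():
    for clip_id, probs in clips.items():
        # probs has shape (#classes, clip_length); overlapping segments were averaged
        print(video_id, clip_id, probs.shape)
```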
307 def find_valleys( 308 self, 309 predicted: Union[torch.Tensor, Dict], 310 threshold: float = 0.5, 311 min_frames: int = 0, 312 visibility_min_score: float = 0, 313 visibility_min_frac: float = 0, 314 main_class: int = 1, 315 low: bool = True, 316 predicted_error: torch.Tensor = None, 317 error_threshold: float = 0.5, 318 hysteresis: bool = False, 319 threshold_diff: float = None, 320 min_frames_error: int = None, 321 smooth_interval: int = 1, 322 cut_annotated: bool = False, 323 ) -> Dict: 324 """Find the intervals where the probability of a certain class is below or above a certain hard_threshold. 325 326 Parameters 327 ---------- 328 predicted : torch.Tensor | dict 329 either a tensor of predictions for the data prompts or the output of 330 `BehaviorDataset.generate_full_length_prediction` 331 threshold : float, default 0.5 332 the main hard_threshold 333 min_frames : int, default 0 334 the minimum length of the intervals 335 visibility_min_score : float, default 0 336 the minimum visibility score in the intervals 337 visibility_min_frac : float, default 0 338 fraction of the interval that has to have the visibility score larger than visibility_score_thr 339 main_class : int, default 1 340 the index of the class the function is inspecting 341 low : bool, default True 342 if True, the probability in the intervals has to be below the hard_threshold, and if False, it has to be above 343 predicted_error : torch.Tensor, optional 344 a tensor of error predictions for the data prompts 345 error_threshold : float, default 0.5 346 maximum possible probability of error at the intervals 347 hysteresis: bool, default False 348 if True, the function will apply a hysteresis hard_threshold with the soft hard_threshold defined by threshold_diff 349 threshold_diff: float, optional 350 the difference between the soft and hard hard_threshold if hysteresis is used; if hysteresis is True, low is False and threshold_diff is None, the soft hard_threshold condition is set to the main_class having a larger probability than other classes 351 min_frames_error: int, optional 352 if not None, the intervals will only be considered where the error probability is below error_threshold at at least min_frames_error consecutive frames 353 smooth_interval: int, default 1 354 the number of frames to smooth the predictions over 355 cut_annotated: bool, default False 356 if `True`, annotated intervals will be cut out of the predicted intervals 357 358 Returns 359 ------- 360 valleys : dict 361 a dictionary where keys are video ids and values are lists of (start, end, individual name) tuples that denote the chosen intervals 362 363 """ 364 result = defaultdict(lambda: []) 365 if type(predicted) is not dict: 366 predicted = self.generate_full_length_prediction(predicted) 367 if predicted_error is not None: 368 predicted_error = self.generate_full_length_prediction(predicted_error) 369 elif min_frames_error is not None and min_frames_error != 0: 370 # warnings.warn( 371 # f"The min_frames_error parameter is set to {min_frames_error} but no error prediction " 372 # f"is given! Setting min_frames_error to 0." 373 # ) 374 min_frames_error = 0 375 if low and hysteresis and threshold_diff is None: 376 raise ValueError( 377 "Cannot set low=True, hysteresis=True and threshold_diff=None! Please set threshold_diff." 
378 ) 379 if cut_annotated: 380 masked_intervals_dict = self.get_annotated_intervals() 381 else: 382 masked_intervals_dict = None 383 print("Valleys found:") 384 for v_id in predicted: 385 for clip_id in predicted[v_id].keys(): 386 if predicted_error is not None: 387 error_mask = predicted_error[v_id][clip_id][1, :] < error_threshold 388 if min_frames_error is not None: 389 output, indices, counts = torch.unique_consecutive( 390 error_mask, return_inverse=True, return_counts=True 391 ) 392 wrong_indices = torch.where( 393 output * (counts < min_frames_error) 394 )[0] 395 if len(wrong_indices) > 0: 396 for i in wrong_indices: 397 error_mask[indices == i] = False 398 else: 399 error_mask = None 400 if masked_intervals_dict is not None: 401 masked_intervals = masked_intervals_dict[v_id][clip_id] 402 else: 403 masked_intervals = None 404 if not hysteresis: 405 res_indices_start, res_indices_end = apply_threshold( 406 predicted[v_id][clip_id][main_class, :], 407 threshold, 408 low, 409 error_mask, 410 min_frames, 411 smooth_interval, 412 masked_intervals, 413 ) 414 elif threshold_diff is not None: 415 if low: 416 soft_threshold = threshold + threshold_diff 417 else: 418 soft_threshold = threshold - threshold_diff 419 res_indices_start, res_indices_end = apply_threshold_hysteresis( 420 predicted[v_id][clip_id][main_class, :], 421 soft_threshold, 422 threshold, 423 low, 424 error_mask, 425 min_frames, 426 smooth_interval, 427 masked_intervals, 428 ) 429 else: 430 res_indices_start, res_indices_end = apply_threshold_max( 431 predicted[v_id][clip_id], 432 threshold, 433 main_class, 434 error_mask, 435 min_frames, 436 smooth_interval, 437 masked_intervals, 438 ) 439 start = self.input_store.get_clip_start(v_id, clip_id) 440 result[v_id] += [ 441 [i + start, j + start, clip_id] 442 for i, j in zip(res_indices_start, res_indices_end) 443 if self.input_store.get_visibility( 444 v_id, clip_id, i, j, visibility_min_score 445 ) 446 > visibility_min_frac 447 ] 448 result[v_id] = sorted(result[v_id]) 449 print(f" {v_id}: {len(result[v_id])}") 450 return dict(result)
Find the intervals where the probability of a certain class is below or above a certain threshold.
Parameters
predicted : torch.Tensor | dict
either a tensor of predictions for the data prompts or the output of `BehaviorDataset.generate_full_length_prediction`
threshold : float, default 0.5
the main (hard) threshold
min_frames : int, default 0
the minimum length of the intervals
visibility_min_score : float, default 0
the minimum visibility score in the intervals
visibility_min_frac : float, default 0
the fraction of the interval that has to have a visibility score larger than `visibility_min_score`
main_class : int, default 1
the index of the class the function is inspecting
low : bool, default True
if `True`, the probability in the intervals has to be below the threshold; if `False`, it has to be above
predicted_error : torch.Tensor, optional
a tensor of error predictions for the data prompts
error_threshold : float, default 0.5
the maximum allowed probability of error in the intervals
hysteresis : bool, default False
if `True`, the function will apply a hysteresis threshold with the soft threshold defined by `threshold_diff`
threshold_diff : float, optional
the difference between the soft and hard thresholds if hysteresis is used; if `hysteresis` is `True`, `low` is `False` and `threshold_diff` is `None`, the soft threshold condition is that `main_class` has a larger probability than the other classes
min_frames_error : int, optional
if not `None`, intervals are only considered where the error probability stays below `error_threshold` for at least `min_frames_error` consecutive frames
smooth_interval : int, default 1
the number of frames to smooth the predictions over
cut_annotated : bool, default False
if `True`, annotated intervals will be cut out of the predicted intervals
Returns
valleys : dict a dictionary where keys are video ids and values are lists of (start, end, individual name) tuples that denote the chosen intervals
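As an illustration, the snippet below looks for low-confidence intervals of one class, reusing the `full_length` dictionary from the sketch above (the threshold values and class index are arbitrary):

```python
valleys = dataset.find_valleys(
    predicted=full_length,  # a full-length prediction dict or a raw (N, #classes, #frames) tensor
    threshold=0.4,
    min_frames=15,
    main_class=1,
    low=True,               # keep intervals where the probability is below the threshold
    cut_annotated=True,     # drop parts that are already annotated
)
for video_id, intervals in valleys.items():
    for start, end, clip_id in intervals:
        print(f"{video_id}/{clip_id}: frames {start}-{end}")
```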
452 def valleys_union(self, valleys_list) -> Dict: 453 """Find the intersection of two valleys dictionaries. 454 455 Parameters 456 ---------- 457 valleys_list : list 458 a list of valleys dictionaries 459 460 Returns 461 ------- 462 intersection : dict 463 a new valleys dictionary with the intersection of the input intervals 464 465 """ 466 valleys_list = [x for x in valleys_list if x is not None] 467 if len(valleys_list) == 1: 468 return valleys_list[0] 469 elif len(valleys_list) == 0: 470 return {} 471 union = {} 472 keys_list = [set(valleys.keys()) for valleys in valleys_list] 473 keys = set.union(*keys_list) 474 for v_id in keys: 475 res = [] 476 clips_list = [ 477 set([x[-1] for x in valleys[v_id]]) for valleys in valleys_list 478 ] 479 clips = set.union(*clips_list) 480 for clip_id in clips: 481 clip_intervals = [ 482 x 483 for valleys in valleys_list 484 for x in valleys[v_id] 485 if x[-1] == clip_id 486 ] 487 v_len = self.input_store.get_clip_length(v_id, clip_id) 488 arr = torch.zeros(v_len) 489 for start, end, _ in clip_intervals: 490 arr[start:end] += 1 491 output, indices, counts = torch.unique_consecutive( 492 arr > 0, return_inverse=True, return_counts=True 493 ) 494 long_indices = torch.where(output)[0] 495 res += [ 496 ( 497 (indices == i).nonzero(as_tuple=True)[0][0].item(), 498 (indices == i).nonzero(as_tuple=True)[0][-1].item(), 499 clip_id, 500 ) 501 for i in long_indices 502 ] 503 union[v_id] = res 504 return union
Find the union of several valleys dictionaries.
Parameters
valleys_list : list a list of valleys dictionaries
Returns
union : dict
a new valleys dictionary with the union of the input intervals
506 def valleys_intersection(self, valleys_list) -> Dict: 507 """Find the intersection of two valleys dictionaries. 508 509 Parameters 510 ---------- 511 valleys_list : list 512 a list of valleys dictionaries 513 514 Returns 515 ------- 516 intersection : dict 517 a new valleys dictionary with the intersection of the input intervals 518 519 """ 520 valleys_list = [x for x in valleys_list if x is not None] 521 if len(valleys_list) == 1: 522 return valleys_list[0] 523 elif len(valleys_list) == 0: 524 return {} 525 intersection = {} 526 keys_list = [set(valleys.keys()) for valleys in valleys_list] 527 keys = set.intersection(*keys_list) 528 for v_id in keys: 529 res = [] 530 clips_list = [ 531 set([x[-1] for x in valleys[v_id]]) for valleys in valleys_list 532 ] 533 clips = set.intersection(*clips_list) 534 for clip_id in clips: 535 clip_intervals = [ 536 x 537 for valleys in valleys_list 538 for x in valleys[v_id] 539 if x[-1] == clip_id 540 ] 541 v_len = self.input_store.get_clip_length(v_id, clip_id) 542 arr = torch.zeros(v_len) 543 for start, end, _ in clip_intervals: 544 arr[start:end] += 1 545 output, indices, counts = torch.unique_consecutive( 546 arr, return_inverse=True, return_counts=True 547 ) 548 long_indices = torch.where(output == 2)[0] 549 res += [ 550 ( 551 (indices == i).nonzero(as_tuple=True)[0][0].item(), 552 (indices == i).nonzero(as_tuple=True)[0][-1].item(), 553 clip_id, 554 ) 555 for i in long_indices 556 ] 557 intersection[v_id] = res 558 return intersection
Find the intersection of several valleys dictionaries.
Parameters
valleys_list : list a list of valleys dictionaries
Returns
intersection : dict a new valleys dictionary with the intersection of the input intervals
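The two operations can be combined, for example to merge candidate intervals from several runs or to keep only those found by all of them (a sketch with arbitrary thresholds):

```python
import torch

predicted = torch.rand(len(dataset), dataset.num_classes(), dataset.len_segment())
valleys_a = dataset.find_valleys(predicted, threshold=0.4, main_class=1)
valleys_b = dataset.find_valleys(predicted, threshold=0.3, main_class=1)

merged = dataset.valleys_union([valleys_a, valleys_b])         # intervals found in either run
common = dataset.valleys_intersection([valleys_a, valleys_b])  # intervals found in both runs
```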
560 def partition_train_test_val( 561 self, 562 use_test: float = 0, 563 split_path: str = None, 564 method: str = "random", 565 val_frac: float = 0, 566 test_frac: float = 0, 567 save_split: bool = False, 568 normalize: bool = False, 569 skip_normalization_keys: List = None, 570 stats: Dict = None, 571 ) -> Tuple: 572 """Partition the dataset into three new datasets. 573 574 Parameters 575 ---------- 576 use_test : float, default 0 577 The fraction of the test dataset to be used in training without labels 578 split_path : str, optional 579 The path to load the split information from (if `'file'` method is used) and to save it to 580 (if `'save_split'` is `True`) 581 method : {'random', 'random:test-from-name', 'random:test-from-name:{name}', 582 'val-from-name:{val_name}:test-from-name:{test_name}', 583 'random:equalize:segments', 'random:equalize:videos', 584 'folders', 'time', 'time:strict', 'file'} 585 The partitioning method: 586 - `'random'`: sort videos into subsets randomly, 587 - `'random:test-from-name'` (or `'random:test-from-name:{name}'`): sort videos into training and validation 588 subsets randomly and create 589 the test subset from the video ids that start with a speific substring (`'test'` by default, or `name` 590 if provided), 591 - `'random:equalize:segments'` and `'random:equalize:videos'`: sort videos into subsets randomly but 592 making sure that for the rarest classes at least `0.8 * val_frac` of the videos/segments that contain 593 occurrences of the class get into the validation subset and `0.8 * test_frac` get into the test subset; 594 this in ensured for all classes in order of increasing number of occurrences until the validation and 595 test subsets are full 596 - `'val-from-name:{val_name}:test-from-name:{test_name}'`: create the validation and test 597 subsets from the video ids that start with specific substrings (`val_name` for validation 598 and `test_name` for test) and sort all other videos into the training subset 599 - `'folders'`: read videos from folders named *test*, *train* and *val* into corresponding subsets, 600 - `'time'`: split each video into training, validation and test subsequences, 601 - `'time:strict'`: split each video into validation, test and training subsequences 602 and throw out the last segments in validation and test (to get rid of overlaps), 603 - `'file'`: split according to a split file. 
604 val_frac : float, default 0 605 The fraction of the dataset to be used in validation 606 test_frac : float, default 0 607 The fraction of the dataset to be used in test 608 save_split : bool, default False 609 Save a split file if True 610 normalize : bool, default False 611 Normalize the dataset if `True` 612 skip_normalization_keys : list, optional 613 A list of keys to skip normalization for 614 stats : dict, optional 615 A dictionary of (pre-computed) statistics to use for normalization 616 617 Returns 618 ------- 619 train_dataset : BehaviorDataset 620 train dataset 621 val_dataset : BehaviorDataset 622 validation dataset 623 test_dataset : BehaviorDataset 624 test dataset 625 626 """ 627 train_indices, test_indices, val_indices = self._partition_indices( 628 split_path=split_path, 629 method=method, 630 val_frac=val_frac, 631 test_frac=test_frac, 632 save_split=save_split, 633 ) 634 ssl_indices = None 635 partition_method = method.split(":") 636 if ( 637 partition_method[0] in ("leave-one-in", "leave-n-in") 638 and len(partition_method) > 1 639 and partition_method[2] == "val-for-ssl" 640 ): 641 print("Using validation samples for SSL!") 642 ssl_indices = val_indices 643 644 val_dataset = self._create_new_dataset(val_indices) 645 test_dataset = self._create_new_dataset(test_indices) 646 train_dataset = self._create_new_dataset(train_indices, ssl_indices=ssl_indices) 647 648 train_classes = train_dataset.count_classes() 649 val_classes = val_dataset.count_classes() 650 test_classes = test_dataset.count_classes() 651 print("Number of samples:") 652 print(f" validation:") 653 print(f" {[f'{k}: {val_classes[k]}' for k in sorted(val_classes.keys())]}") 654 print(f" training:") 655 print(f" {[f'{k}: {train_classes[k]}' for k in sorted(train_classes.keys())]}") 656 print(f" test:") 657 print(f" {[f'{k}: {test_classes[k]}' for k in sorted(test_classes.keys())]}") 658 if normalize: 659 if stats is None: 660 print("Computing normalization statistics...") 661 stats = train_dataset.get_normalization_stats(skip_normalization_keys) 662 else: 663 print("Setting loaded normalization statistics...") 664 train_dataset.set_normalization_stats(stats) 665 val_dataset.set_normalization_stats(stats) 666 test_dataset.set_normalization_stats(stats) 667 return train_dataset, test_dataset, val_dataset
Partition the dataset into three new datasets.
Parameters
use_test : float, default 0
The fraction of the test dataset to be used in training without labels
split_path : str, optional
The path to load the split information from (if the `'file'` method is used) and to save it to (if `save_split` is `True`)
method : {'random', 'random:test-from-name', 'random:test-from-name:{name}', 'val-from-name:{val_name}:test-from-name:{test_name}', 'random:equalize:segments', 'random:equalize:videos', 'folders', 'time', 'time:strict', 'file'}
The partitioning method:
- `'random'`: sort videos into subsets randomly,
- `'random:test-from-name'` (or `'random:test-from-name:{name}'`): sort videos into training and validation subsets randomly and create the test subset from the video ids that start with a specific substring (`'test'` by default, or `name` if provided),
- `'random:equalize:segments'` and `'random:equalize:videos'`: sort videos into subsets randomly while making sure that for the rarest classes at least `0.8 * val_frac` of the videos/segments that contain occurrences of the class get into the validation subset and `0.8 * test_frac` get into the test subset; this is ensured for all classes in order of increasing number of occurrences until the validation and test subsets are full,
- `'val-from-name:{val_name}:test-from-name:{test_name}'`: create the validation and test subsets from the video ids that start with specific substrings (`val_name` for validation and `test_name` for test) and sort all other videos into the training subset,
- `'folders'`: read videos from folders named *test*, *train* and *val* into the corresponding subsets,
- `'time'`: split each video into training, validation and test subsequences,
- `'time:strict'`: split each video into validation, test and training subsequences and throw out the last segments in validation and test (to get rid of overlaps),
- `'file'`: split according to a split file.
val_frac : float, default 0
The fraction of the dataset to be used in validation
test_frac : float, default 0
The fraction of the dataset to be used in test
save_split : bool, default False
Save a split file if True
normalize : bool, default False
Normalize the dataset if True
skip_normalization_keys : list, optional
A list of keys to skip normalization for
stats : dict, optional
A dictionary of (pre-computed) statistics to use for normalization
Returns
train_dataset : BehaviorDataset
the training dataset
test_dataset : BehaviorDataset
the test dataset
val_dataset : BehaviorDataset
the validation dataset
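A typical call might look like the sketch below (the method, fractions and file name are examples; `normalize=True` triggers the statistics computation described further down this page). Note that the subsets are returned in train, test, val order:

```python
train_ds, test_ds, val_ds = dataset.partition_train_test_val(
    method="random",
    val_frac=0.2,
    test_frac=0.1,
    split_path="split.txt",  # assumed file name, used by save_split and the 'file' method
    save_split=True,
    normalize=True,
)
```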
669 def class_weights(self, proportional=False) -> List: 670 """Calculate class weights in inverse proportion to number of samples. 671 672 Parameters 673 ---------- 674 proportional : bool, default False 675 If `True`, the weights are proportional to the number of samples in the most common class 676 677 Returns 678 ------- 679 weights: list 680 a list of class weights 681 682 """ 683 items = sorted( 684 [ 685 (k, v) 686 for k, v in self.annotation_store.count_classes().items() 687 if k != -100 688 ] 689 ) 690 if self.annotation_store.annotation_class() == "exclusive_classification": 691 if not proportional: 692 numerator = len(self.annotation_store) 693 else: 694 numerator = max([x[1] for x in items]) 695 weights = [numerator / (v + 1e-7) for _, v in items] 696 else: 697 items_zero = sorted( 698 [ 699 (k, v) 700 for k, v in self.annotation_store.count_classes(zeros=True).items() 701 if k != -100 702 ] 703 ) 704 if not proportional: 705 numerators = defaultdict(lambda: len(self.annotation_store)) 706 else: 707 numerators = { 708 item_one[0]: max(item_one[1], item_zero[1]) 709 for item_one, item_zero in zip(items, items_zero) 710 } 711 weights = {} 712 weights[0] = [numerators[k] / (v + 1e-7) for k, v in items_zero] 713 weights[1] = [numerators[k] / (v + 1e-7) for k, v in items] 714 return weights
Calculate class weights in inverse proportion to number of samples.
Parameters
proportional : bool, default False
if `True`, the weights are proportional to the number of samples in the most common class
Returns
weights: list a list of class weights
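For exclusive classification the returned list can be passed directly to a weighted loss; a sketch (note that `-100` marks unlabeled frames in the annotation):

```python
import torch

weights = dataset.class_weights()  # list of per-class weights (exclusive classification)
criterion = torch.nn.CrossEntropyLoss(
    weight=torch.tensor(weights, dtype=torch.float32),
    ignore_index=-100,  # ignore unlabeled frames
)
```

For non-exclusive annotation the method instead returns a dictionary with separate weight lists for the negative (`0`) and positive (`1`) cases.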
737 def count_classes(self, bouts: bool = False) -> Dict: 738 """Get a class counter dictionary. 739 740 Parameters 741 ---------- 742 bouts : bool, default False 743 if `True`, instead of frame counts segment counts are returned 744 745 Returns 746 ------- 747 count_dictionary : dict 748 a dictionary with class indices as keys and frame or bout counts as values 749 750 """ 751 return self.annotation_store.count_classes(bouts=bouts)
Get a class counter dictionary.
Parameters
bouts : bool, default False
if `True`, bout (segment) counts are returned instead of frame counts
Returns
count_dictionary : dict a dictionary with class indices as keys and frame or bout counts as values
753 def behaviors_dict(self) -> Dict: 754 """Get a behavior dictionary. 755 756 Returns 757 ------- 758 dict 759 behavior dictionary 760 761 """ 762 return self.annotation_store.behaviors_dict()
Get a behavior dictionary.
Returns
behaviors_dict : dict
a dictionary where keys are class indices and values are behavior names
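Together with `count_classes`, this gives a quick overview of the label distribution (a sketch):

```python
behaviors = dataset.behaviors_dict()     # {class index: behavior name}
frame_counts = dataset.count_classes()   # {class index: number of frames}
bout_counts = dataset.count_classes(bouts=True)

for idx, name in behaviors.items():
    print(f"{name}: {frame_counts.get(idx, 0)} frames, {bout_counts.get(idx, 0)} bouts")
```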
764 def bodyparts_order(self) -> List: 765 """Get the order of bodyparts. 766 767 Returns 768 ------- 769 bodyparts : List 770 a list of bodyparts 771 772 """ 773 try: 774 return self.input_store.get_bodyparts() 775 except: 776 raise RuntimeError( 777 f"The {self.input_type} input store does not have bodyparts implemented!" 778 )
Get the order of bodyparts.
Returns
bodyparts : List a list of bodyparts
780 def features_shape(self) -> Dict: 781 """Get the shapes of the input features. 782 783 Returns 784 ------- 785 shapes : Dict 786 a dictionary with the shapes of the features 787 788 """ 789 sample = self.input_store[0] 790 shapes = {k: v.shape for k, v in sample.items()} 791 # for key, value in shapes.items(): 792 # print(f'{key}: {value}') 793 return shapes
Get the shapes of the input features.
Returns
shapes : Dict a dictionary with the shapes of the features
795 def num_classes(self) -> int: 796 """Get the number of classes in the data. 797 798 Returns 799 ------- 800 num_classes : int 801 the number of classes 802 803 """ 804 return len(self.annotation_store.behaviors_dict())
Get the number of classes in the data.
Returns
num_classes : int the number of classes
806 def len_segment(self) -> int: 807 """Get the segment length in the data. 808 809 Returns 810 ------- 811 len_segment : int 812 the segment length 813 814 """ 815 sample = self.input_store[0] 816 key = list(sample.keys())[0] 817 return sample[key].shape[-1]
Get the segment length in the data.
Returns
len_segment : int the segment length
819 def set_ssl_transformations(self, ssl_transformations: List) -> None: 820 """Set new SSL transformations. 821 822 Parameters 823 ---------- 824 ssl_transformations : list 825 a list of functions that take a sample feature dictionary as input and output ssl_inputs and ssl_targets 826 lists 827 828 """ 829 self.ssl_transformations = ssl_transformations
Set new SSL transformations.
Parameters
ssl_transformations : list a list of functions that take a sample feature dictionary as input and output ssl_inputs and ssl_targets lists
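A minimal sketch of such a transformation, assuming it simply masks a random fraction of the frames and asks the SSL module to reconstruct them (the exact contract expected by a specific SSL constructor may differ):

```python
import torch

def masking_transformation(sample: dict):
    """Build an (ssl input, ssl target) pair from a sample feature dictionary."""
    ssl_input, ssl_target = {}, {}
    for key, value in sample.items():
        target = value.clone()
        masked = value.clone()
        mask = torch.rand(value.shape[-1]) < 0.1  # mask ~10% of the frames
        masked[..., mask] = 0
        ssl_input[key] = masked
        ssl_target[key] = target
    return ssl_input, ssl_target

dataset.set_ssl_transformations([masking_transformation])
```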
831 @classmethod 832 def new(cls, *args, **kwargs): 833 """Create a new object of the same class. 834 835 Parameters 836 ---------- 837 args : list 838 arguments for the constructor 839 kwargs : dict 840 keyword arguments for the constructor 841 842 Returns 843 ------- 844 new_instance: BehaviorDataset 845 a new instance of the same class 846 847 """ 848 return cls(*args, **kwargs)
Create a new object of the same class.
Parameters
args : list
arguments for the constructor
kwargs : dict
keyword arguments for the constructor
Returns
new_instance: BehaviorDataset a new instance of the same class
850 @classmethod 851 def get_parameters(cls, data_type: str, annotation_type: str) -> List: 852 """Get parameters necessary for initialization. 853 854 Parameters 855 ---------- 856 data_type : str 857 the data type 858 annotation_type : str 859 the annotation type 860 861 Returns 862 ------- 863 parameters : list 864 a list of parameters 865 866 """ 867 input_features = options.input_stores[data_type].get_parameters() 868 annotation_features = options.annotation_stores[ 869 annotation_type 870 ].get_parameters() 871 self_features = inspect.getfullargspec(cls.__init__).args 872 return self_features + input_features + annotation_features
Get parameters necessary for initialization.
Parameters
data_type : str
the data type
annotation_type : str
the annotation type
Returns
parameters : list a list of parameters
874 @staticmethod 875 def data_types() -> List: 876 """List available data types. 877 878 Returns 879 ------- 880 data_types : list 881 available data types 882 883 """ 884 return list(options.input_stores.keys())
List available data types.
Returns
data_types : list available data types
886 @staticmethod 887 def annotation_types() -> List: 888 """List available annotation types. 889 890 Returns 891 ------- 892 annotation_types : list 893 available annotation types 894 895 """ 896 return list(options.annotation_stores.keys())
List available annotation types.
Returns
annotation_types : list available annotation types
952 def set_indexing_parameters(self, unlabeled: bool, tag: int) -> None: 953 """Set the parameters that change the subset that is returned at `__getitem__`. 954 955 Parameters 956 ---------- 957 unlabeled : bool 958 a pseudolabeling parameter; return only unlabeled samples if `True`, only labeled if `False` and 959 all if `None` 960 tag : int 961 if not `None`, only samples with this meta tag will be returned 962 963 """ 964 if unlabeled != self.return_unlabeled: 965 self.annotation_indices = self.annotation_store.get_indices(unlabeled) 966 self.return_unlabeled = unlabeled 967 if tag != self.tag: 968 self.input_indices = self.input_store.get_indices(tag) 969 self.tag = tag 970 self.indices = [x for x in self.annotation_indices if x in self.input_indices]
Set the parameters that change the subset that is returned at `__getitem__`.
Parameters
unlabeled : bool
a pseudolabeling parameter; return only unlabeled samples if `True`, only labeled if `False` and all if `None`
tag : int
if not `None`, only samples with this meta tag will be returned
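For example, during pseudolabeling one might restrict iteration to the unlabeled samples (a sketch):

```python
dataset.set_indexing_parameters(unlabeled=True, tag=None)  # only unlabeled samples
print(len(dataset.indices), "unlabeled samples selected")
dataset.set_indexing_parameters(unlabeled=None, tag=None)  # back to all samples
```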
1413 def get_intervals(self) -> Tuple[dict, Optional[list]]: 1414 """Get a list of intervals covered by the dataset in the original coordinates. 1415 1416 Returns 1417 ------- 1418 intervals : dict 1419 a nested dictionary where first-level keys are video ids, second-level keys are clip ids and 1420 values are lists of the intervals in `[start, end]` format 1421 1422 """ 1423 counter = defaultdict(lambda: {}) 1424 coordinates = self.input_store.get_original_coordinates() 1425 for coords in coordinates: 1426 l = self.input_store.get_clip_length_from_coords(coords) 1427 video_name = self.input_store.get_video_id(coords) 1428 clip_id = self.input_store.get_clip_id(coords) 1429 start, end = self.input_store.get_clip_start_end(coords) 1430 if clip_id not in counter[video_name]: 1431 counter[video_name][clip_id] = np.zeros(l) 1432 counter[video_name][clip_id][start:end] = 1 1433 result = {video_name: {} for video_name in counter} 1434 for video_name in counter: 1435 for clip_id in counter[video_name]: 1436 result[video_name][clip_id] = self._get_intervals_from_ind( 1437 np.where(counter[video_name][clip_id])[0] 1438 ) 1439 return result, self.ids
Get a list of intervals covered by the dataset in the original coordinates.
Returns
intervals : dict
a nested dictionary where first-level keys are video ids, second-level keys are clip ids and values are lists of the intervals in `[start, end]` format
1441 def get_unannotated_intervals(self, first_intervals=None) -> Dict: 1442 """Get a list of intervals in the original coordinates where there is no annotation. 1443 1444 Parameters 1445 ---------- 1446 first_intervals : dict 1447 a nested dictionary where first-level keys are video ids, second-level keys are clip ids and 1448 values are lists of the intervals in `[start, end]` format. If provided, only the intersection with 1449 those intervals will be returned 1450 1451 Returns 1452 ------- 1453 intervals : dict 1454 a nested dictionary where first-level keys are video ids, second-level keys are clip ids and 1455 values are lists of the intervals in `[start, end]` format 1456 1457 """ 1458 counter_value = 2 1459 if first_intervals is None: 1460 first_intervals = defaultdict(lambda: defaultdict(lambda: [])) 1461 counter_value = 1 1462 counter = defaultdict(lambda: {}) 1463 coordinates = self.input_store.get_original_coordinates() 1464 for i, coords in enumerate(coordinates): 1465 l = self.input_store.get_clip_length_from_coords(coords) 1466 ann = self.annotation_store[i] 1467 if ( 1468 self.annotation_store.annotation_class() 1469 == "nonexclusive_classification" 1470 ): 1471 ann = ann[0, :] 1472 video_name = self.input_store.get_video_id(coords) 1473 clip_id = self.input_store.get_clip_id(coords) 1474 start, end = self.input_store.get_clip_start_end(coords) 1475 if clip_id not in counter[video_name]: 1476 counter[video_name][clip_id] = np.ones(l) 1477 counter[video_name][clip_id][start:end] = (ann[: end - start] == -100).int() 1478 result = {video_name: {} for video_name in counter} 1479 for video_name in counter: 1480 for clip_id in counter[video_name]: 1481 for start, end in first_intervals[video_name][clip_id]: 1482 counter[video_name][clip_id][start:end] += 1 1483 result[video_name][clip_id] = self._get_intervals_from_ind( 1484 np.where(counter[video_name][clip_id] == counter_value)[0] 1485 ) 1486 return result
Get a list of intervals in the original coordinates where there is no annotation.
Parameters
first_intervals : dict
a nested dictionary where first-level keys are video ids, second-level keys are clip ids and values are lists of the intervals in `[start, end]` format; if provided, only the intersection with those intervals will be returned
Returns
intervals : dict
a nested dictionary where first-level keys are video ids, second-level keys are clip ids and values are lists of the intervals in `[start, end]` format
1488 def get_annotated_intervals(self) -> Dict: 1489 """Get a list of intervals in the original coordinates where there is no annotation. 1490 1491 Returns 1492 ------- 1493 intervals : dict 1494 a nested dictionary where first-level keys are video ids, second-level keys are clip ids and 1495 values are lists of the intervals in `[start, end]` format 1496 1497 """ 1498 if self.annotation_type == "none": 1499 return [] 1500 counter_value = 1 1501 counter = defaultdict(lambda: {}) 1502 coordinates = self.input_store.get_original_coordinates() 1503 for i, coords in enumerate(coordinates): 1504 l = self.input_store.get_clip_length_from_coords(coords) 1505 ann = self.annotation_store[i] 1506 video_name = self.input_store.get_video_id(coords) 1507 clip_id = self.input_store.get_clip_id(coords) 1508 start, end = self.input_store.get_clip_start_end(coords) 1509 if clip_id not in counter[video_name]: 1510 counter[video_name][clip_id] = np.zeros(l) 1511 if ( 1512 self.annotation_store.annotation_class() 1513 == "nonexclusive_classification" 1514 ): 1515 counter[video_name][clip_id][start:end] = ( 1516 torch.sum(ann[:, : end - start] != -100, dim=0) > 0 1517 ).int() 1518 else: 1519 counter[video_name][clip_id][start:end] = ( 1520 ann[: end - start] != -100 1521 ).int() 1522 result = {video_name: {} for video_name in counter} 1523 for video_name in counter: 1524 for clip_id in counter[video_name]: 1525 result[video_name][clip_id] = self._get_intervals_from_ind( 1526 np.where(counter[video_name][clip_id] == counter_value)[0] 1527 ) 1528 return result
Get a list of intervals in the original coordinates where there is annotation.
Returns
intervals : dict
a nested dictionary where first-level keys are video ids, second-level keys are clip ids and values are lists of the intervals in `[start, end]` format
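The interval helpers share the same nested dictionary layout, so they can be inspected the same way (a sketch):

```python
annotated = dataset.get_annotated_intervals()
unannotated = dataset.get_unannotated_intervals()

for video_id, clips in annotated.items():
    for clip_id, intervals in clips.items():
        for start, end in intervals:  # each interval is a [start, end] pair
            print(f"{video_id}/{clip_id}: annotated frames {start}-{end}")
```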
1530 def get_ids(self) -> Dict: 1531 """Get a dictionary of all clip ids in the dataset. 1532 1533 Returns 1534 ------- 1535 ids : dict 1536 a dictionary where keys are video ids and values are lists of clip ids 1537 1538 """ 1539 coordinates = self.input_store.get_original_coordinates() 1540 video_ids = np.array(self.input_store.get_video_id_order()) 1541 id_set = set(video_ids) 1542 result = {} 1543 for video_id in id_set: 1544 coords = coordinates[video_ids == video_id] 1545 clip_ids = list({self.input_store.get_clip_id(c) for c in coords}) 1546 result[video_id] = clip_ids 1547 return result
Get a dictionary of all clip ids in the dataset.
Returns
ids : dict a dictionary where keys are video ids and values are lists of clip ids
1549 def get_len(self, video_id: str, clip_id: str) -> int: 1550 """Get the length of a specific clip. 1551 1552 Parameters 1553 ---------- 1554 video_id : str 1555 the video id 1556 clip_id : str 1557 the clip id 1558 1559 Returns 1560 ------- 1561 length : int 1562 the length 1563 1564 """ 1565 return self.input_store.get_clip_length(video_id, clip_id)
Get the length of a specific clip.
Parameters
video_id : str
the video id
clip_id : str
the clip id
Returns
length : int the length
1567 def get_confusion_matrix( 1568 self, prediction: torch.Tensor, confusion_type: str = "recall" 1569 ) -> Tuple[ndarray, list]: 1570 """Get a confusion matrix. 1571 1572 Parameters 1573 ---------- 1574 prediction : torch.Tensor 1575 a tensor of predicted class probabilities of shape `(#samples, #classes, #frames)` 1576 confusion_type : {"recall", "precision"} 1577 for datasets with non-exclusive annotation, if `type` is `"recall"`, only false positives are taken 1578 into account, and if `type` is `"precision"`, only false negatives 1579 1580 Returns 1581 ------- 1582 confusion_matrix : np.ndarray 1583 a confusion matrix of shape `(#classes, #classes)` where `A[i, j] = F_ij/N_i`, `F_ij` is the number of 1584 frames that have the i-th label in the ground truth and a false positive j-th label in the prediction, 1585 `N_i` is the number of frames that have the i-th label in the ground truth 1586 classes : list 1587 a list of classes 1588 1589 """ 1590 behaviors_dict = self.annotation_store.behaviors_dict() 1591 num_behaviors = len(behaviors_dict) 1592 confusion_matrix = np.zeros((num_behaviors, num_behaviors)) 1593 if self.annotation_store.annotation_class() == "exclusive_classification": 1594 exclusive = True 1595 confusion_type = None 1596 elif self.annotation_store.annotation_class() == "nonexclusive_classification": 1597 exclusive = False 1598 else: 1599 raise RuntimeError( 1600 f"The {self.annotation_store.annotation_class()} annotation class is not recognized!" 1601 ) 1602 for ann, p in zip(self.annotation_store, prediction): 1603 if exclusive: 1604 class_prediction = torch.max(p, dim=0)[1] 1605 for i in behaviors_dict.keys(): 1606 for j in behaviors_dict.keys(): 1607 confusion_matrix[i, j] += int( 1608 torch.sum(class_prediction[ann == i] == j) 1609 ) 1610 else: 1611 class_prediction = (p > 0.5).int() 1612 for i in behaviors_dict.keys(): 1613 for j in behaviors_dict.keys(): 1614 if confusion_type == "recall": 1615 pred = deepcopy(class_prediction[j]) 1616 if i != j: 1617 pred[ann[j] == 1] = 0 1618 confusion_matrix[i, j] += int(torch.sum(pred[ann[i] == 1])) 1619 elif confusion_type == "precision": 1620 annotation = deepcopy(ann[j]) 1621 if i != j: 1622 annotation[class_prediction[j] == 1] = 0 1623 confusion_matrix[i, j] += int( 1624 torch.sum(annotation[class_prediction[i] == 1]) 1625 ) 1626 else: 1627 raise ValueError( 1628 f"The {confusion_type} type is not recognized; please choose from ['recall', 'precision']" 1629 ) 1630 counter = self.annotation_store.count_classes() 1631 for i in behaviors_dict.keys(): 1632 if counter[i] != 0: 1633 if confusion_type == "recall" or confusion_type is None: 1634 confusion_matrix[i, :] /= counter[i] 1635 else: 1636 confusion_matrix[:, i] /= counter[i] 1637 return confusion_matrix, list(behaviors_dict.values()), confusion_type
Get a confusion matrix.
Parameters
prediction : torch.Tensor
a tensor of predicted class probabilities of shape `(#samples, #classes, #frames)`
confusion_type : {"recall", "precision"}
for datasets with non-exclusive annotation, if `confusion_type` is `"recall"`, only false positives are taken into account, and if it is `"precision"`, only false negatives
Returns
confusion_matrix : np.ndarray
a confusion matrix of shape `(#classes, #classes)` where `A[i, j] = F_ij / N_i`, `F_ij` is the number of frames that have the i-th label in the ground truth and a false positive j-th label in the prediction, and `N_i` is the number of frames that have the i-th label in the ground truth
classes : list
a list of classes
confusion_type : str
the confusion type that was actually used (`None` for exclusive classification)
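A sketch of computing and printing the matrix from stacked model probabilities (random values stand in for real predictions):

```python
import numpy as np
import torch

predicted = torch.rand(len(dataset), dataset.num_classes(), dataset.len_segment())
matrix, classes, used_type = dataset.get_confusion_matrix(predicted, confusion_type="recall")

# row i is normalized by the number of ground-truth frames of class i
with np.printoptions(precision=2, suppress=True):
    print(classes)
    print(matrix)
```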
1639 def annotation_class(self) -> str: 1640 """Get the type of annotation ('exclusive_classification', 'nonexclusive_classification', more coming soon). 1641 1642 Returns 1643 ------- 1644 annotation_class : str 1645 the type of annotation 1646 1647 """ 1648 return self.annotation_store.annotation_class()
Get the type of annotation ('exclusive_classification', 'nonexclusive_classification', more coming soon).
Returns
annotation_class : str the type of annotation
1650 def set_normalization_stats(self, stats: Dict) -> None: 1651 """Set the stats to normalize data at runtime. 1652 1653 Parameters 1654 ---------- 1655 stats : dict 1656 a nested dictionary where first-level keys are feature key names, second-level keys are 'mean' and 'std' 1657 and values are the statistics in `torch` tensors of shape `(#features, 1)` 1658 1659 """ 1660 self.stats = stats
Set the stats to normalize data at runtime.
Parameters
stats : dict
a nested dictionary where first-level keys are feature key names, second-level keys are 'mean' and 'std' and values are the statistics in `torch` tensors of shape `(#features, 1)`
1662 def get_min_max_frames(self, video_id) -> Tuple[Dict, Dict]: 1663 """Get the minimum and maximum frame numbers for each clip in a video. 1664 1665 Parameters 1666 ---------- 1667 video_id : str 1668 the video id 1669 1670 Returns 1671 ------- 1672 min_frames : dict 1673 a dictionary where keys are clip ids and values are the minimum frame numbers 1674 max_frames : dict 1675 a dictionary where keys are clip ids and values are the maximum frame numbers 1676 1677 """ 1678 coords = self.input_store.get_original_coordinates() 1679 clips = set( 1680 [ 1681 self.input_store.get_clip_id(c) 1682 for c in coords 1683 if self.input_store.get_video_id(c) == video_id 1684 ] 1685 ) 1686 min_frames = {} 1687 max_frames = {} 1688 for clip in clips: 1689 start = self.input_store.get_clip_start(video_id, clip) 1690 end = start + self.input_store.get_clip_length(video_id, clip) 1691 min_frames[clip] = start 1692 max_frames[clip] = end - 1 1693 return min_frames, max_frames
Get the minimum and maximum frame numbers for each clip in a video.
Parameters
video_id : str the video id
Returns
min_frames : dict
a dictionary where keys are clip ids and values are the minimum frame numbers
max_frames : dict
a dictionary where keys are clip ids and values are the maximum frame numbers
1695 def get_normalization_stats(self, skip_keys=None) -> Dict: 1696 """Get mean and standard deviation for each key. 1697 1698 Parameters 1699 ---------- 1700 skip_keys : list, optional 1701 a list of keys to skip 1702 1703 Returns 1704 ------- 1705 stats : dict 1706 a nested dictionary where first-level keys are feature key names, second-level keys are 'mean' and 'std' 1707 and values are the statistics in `torch` tensors of shape `(#features, 1)` 1708 1709 """ 1710 stats = defaultdict(lambda: {}) 1711 sums = defaultdict(lambda: 0) 1712 if skip_keys is None: 1713 skip_keys = [] 1714 counter = defaultdict(lambda: 0) 1715 for sample in tqdm(self): 1716 for key, value in sample["input"].items(): 1717 key_name = key.split("---")[0] 1718 if key_name not in skip_keys: 1719 sums[key_name] += value[:, value.sum(0) != 0].sum(-1) 1720 counter[key_name] += torch.sum(value.sum(0) != 0) 1721 for key, value in sums.items(): 1722 stats[key]["mean"] = (value / counter[key]).unsqueeze(-1) 1723 sums = defaultdict(lambda: 0) 1724 for sample in tqdm(self): 1725 for key, value in sample["input"].items(): 1726 key_name = key.split("---")[0] 1727 if key_name not in skip_keys: 1728 sums[key_name] += ( 1729 (value[:, value.sum(0) != 0] - stats[key_name]["mean"]) ** 2 1730 ).sum(-1) 1731 for key, value in sums.items(): 1732 stats[key]["std"] = np.sqrt(value.unsqueeze(-1) / counter[key]) 1733 return stats
Get mean and standard deviation for each key.
Parameters
skip_keys : list, optional a list of keys to skip
Returns
stats : dict
a nested dictionary where first-level keys are feature key names, second-level keys are 'mean' and 'std' and values are the statistics in `torch` tensors of shape `(#features, 1)`
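In practice the statistics are computed on the training subset and shared with the other subsets, which is also what `partition_train_test_val(..., normalize=True)` does internally. A manual sketch, reusing the subsets from the partitioning example above (`'speed'` is a hypothetical key to skip):

```python
stats = train_ds.get_normalization_stats(skip_keys=["speed"])
train_ds.set_normalization_stats(stats)
val_ds.set_normalization_stats(stats)
test_ds.set_normalization_stats(stats)
```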