dlc2action.data.dataset
Behavior dataset (class that manages high-level data interactions)
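A minimal usage sketch is shown below. The constructor signature, partition_train_test_val and the returned batch layout come from the source listed further down; the type names, file paths and store-specific keyword arguments are illustrative placeholders, so check BehaviorDataset.data_types() and BehaviorDataset.annotation_types() for what is registered in your installation.

from dlc2action.data.dataset import BehaviorDataset

print(BehaviorDataset.data_types())        # registered InputStore types
print(BehaviorDataset.annotation_types())  # registered AnnotationStore types

# Hypothetical parameters: "dlc_track", "csv" and the *_path arguments are
# placeholders for whatever the chosen input/annotation stores expect
# (extra keyword arguments are passed through **data_parameters).
dataset = BehaviorDataset(
    data_type="dlc_track",
    annotation_type="csv",
    data_path="/path/to/tracking_files",
    annotation_path="/path/to/annotation_files",
    only_load_annotated=True,
)

# partition into train, test and validation subsets (returned in that order)
train_dataset, test_dataset, val_dataset = dataset.partition_train_test_val(
    method="random", val_frac=0.2, test_frac=0.1
)
sample = train_dataset[0]  # {"input": {...}, "target": ..., "index": ...}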
1# 2# Copyright 2020-2022 by A. Mathis Group and contributors. All rights reserved. 3# 4# This project and all its files are licensed under GNU AGPLv3 or later version. A copy is included in dlc2action/LICENSE.AGPL. 5# 6""" 7Behavior dataset (class that manages high-level data interactions) 8""" 9 10import warnings 11from typing import Dict, Union, Tuple, List, Optional, Any 12 13from numpy import ndarray 14from torch.utils.data import Dataset 15import torch 16from abc import ABC 17import numpy as np 18from copy import copy 19from tqdm import tqdm 20from collections import Counter 21import inspect 22from collections import defaultdict 23from dlc2action.utils import ( 24 apply_threshold_hysteresis, 25 apply_threshold, 26 apply_threshold_max, 27) 28from dlc2action.data.base_store import InputStore, AnnotationStore 29from copy import deepcopy, copy 30import os 31import pickle 32from dlc2action import options 33 34 35class BehaviorDataset(Dataset, ABC): 36 """ 37 A generalized dataset class 38 39 Data and annotation are stored in separate InputStore and AnnotationStore objects; the dataset class 40 manages their interactions. 41 """ 42 43 def __init__( 44 self, 45 data_type: str, 46 annotation_type: str = "none", 47 ssl_transformations: List = None, 48 saved_data_path: str = None, 49 input_store: InputStore = None, 50 annotation_store: AnnotationStore = None, 51 only_load_annotated: bool = False, 52 recompute_annotation: bool = False, 53 # mask: str = None, 54 ids: List = None, 55 **data_parameters, 56 ) -> None: 57 """ 58 Parameters 59 ---------- 60 data_type : str 61 the data type (see available types by running BehaviorDataset.data_types()) 62 annotation_type : str 63 the annotation type (see available types by running BehaviorDataset.annotation_types()) 64 ssl_transformations : list 65 a list of functions that take a sample dictionary as input and return an (ssl input, ssl target) tuple 66 saved_data_path : str 67 the path to a pre-computed pickled dataset 68 input_store : InputStore 69 a pre-computed input store 70 annotation_store : AnnotationStore 71 a precomputed annotation store 72 only_load_annotated : bool 73 if True, the input files that don't have a matching annotation file will be disregarded 74 *data_parameters : dict 75 parameters to initialize the input and annotation stores 76 """ 77 78 mask = None 79 if len(data_parameters) == 0: 80 recompute_annotation = False 81 feature_extraction = data_parameters.get("feature_extraction") 82 if feature_extraction is not None and not issubclass( 83 options.input_stores[data_type], 84 options.feature_extractors[feature_extraction].input_store_class, 85 ): 86 raise ValueError( 87 f"The {feature_extraction} feature extractor does not work with " 88 f"the {data_type} data type, please choose a suclass of " 89 f"{options.feature_extractors[feature_extraction].input_store_class}" 90 ) 91 if ssl_transformations is None: 92 ssl_transformations = [] 93 self.ssl_transformations = ssl_transformations 94 self.input_type = data_type 95 self.annotation_type = annotation_type 96 self.stats = None 97 if mask is not None: 98 with open(mask, "rb") as f: 99 self.mask = pickle.load(f) 100 else: 101 self.mask = None 102 self.ids = ids 103 self.tag = None 104 self.return_unlabeled = None 105 # load saved key objects for annotation and input if they exist 106 input_key_objects, annotation_key_objects = None, None 107 if saved_data_path is not None: 108 if os.path.exists(saved_data_path): 109 with open(saved_data_path, "rb") as f: 110 input_key_objects, 
annotation_key_objects = pickle.load(f) 111 # if the input or the annotation store need to be created, generate the common video order 112 if len(data_parameters) > 0: 113 input_files = options.input_stores[data_type].get_file_ids( 114 **data_parameters 115 ) 116 annotation_files = options.annotation_stores[annotation_type].get_file_ids( 117 **data_parameters 118 ) 119 if only_load_annotated: 120 data_parameters["video_order"] = [ 121 x for x in input_files if x in annotation_files 122 ] 123 else: 124 data_parameters["video_order"] = input_files 125 if len(data_parameters["video_order"]) == 0: 126 raise RuntimeError( 127 "The length of file list is 0! Please check your data parameters!" 128 ) 129 data_parameters["mask"] = self.mask 130 # load or create the input store 131 ok = False 132 if input_store is not None: 133 self.input_store = input_store 134 ok = True 135 elif input_key_objects is not None: 136 try: 137 self.input_store = self._load_input_store(data_type, input_key_objects) 138 ok = True 139 except: 140 warnings.warn("Loading input store from key objects failed") 141 if not ok: 142 self.input_store = self._get_input_store( 143 data_type, deepcopy(data_parameters) 144 ) 145 # get the objects needed to create the annotation store (like a clip length dictionary) 146 annotation_objects = self.input_store.get_annotation_objects() 147 data_parameters.update(annotation_objects) 148 # load or create the annotation store 149 ok = False 150 if annotation_store is not None: 151 self.annotation_store = annotation_store 152 ok = True 153 elif ( 154 annotation_key_objects is not None 155 and mask is None 156 and not recompute_annotation 157 ): 158 try: 159 self.annotation_store = self._load_annotation_store( 160 annotation_type, annotation_key_objects 161 ) 162 ok = True 163 except: 164 warnings.warn("Loading annotation store from key objects failed") 165 if not ok: 166 self.annotation_store = self._get_annotation_store( 167 annotation_type, deepcopy(data_parameters) 168 ) 169 if ( 170 mask is None 171 and annotation_type != "none" 172 and not recompute_annotation 173 and ( 174 self.annotation_store.get_original_coordinates() 175 != self.input_store.get_original_coordinates() 176 ).any() 177 ): 178 raise RuntimeError( 179 "The clip orders in the annotation store and input store are different!" 
180 ) 181 # filter the data based on data parameters 182 # print(f"1 {self.annotation_store.get_original_coordinates().shape=}") 183 # print(f"1 {self.input_store.get_original_coordinates().shape=}") 184 to_remove = self.annotation_store.filtered_indices() 185 if len(to_remove) > 0: 186 print( 187 f"Filtering {100 * len(to_remove) / len(self.annotation_store):.2f}% of samples" 188 ) 189 if len(self.input_store) == len(self.annotation_store): 190 self.input_store.remove(to_remove) 191 self.annotation_store.remove(to_remove) 192 self.input_indices = list(range(len(self.input_store))) 193 self.annotation_indices = list(range(len(self.input_store))) 194 self.indices = list(range(len(self.input_store))) 195 # print(f'{data_parameters["video_order"]=}') 196 # print(f"{self.annotation_store.get_original_coordinates().shape=}") 197 # print(f"{self.input_store.get_original_coordinates().shape=}") 198 # count = 0 199 # for i, (x, y) in enumerate(zip( 200 # self.annotation_store.get_original_coordinates(), 201 # self.input_store.get_original_coordinates(), 202 # )): 203 # if (x != y).any(): 204 # count += 1 205 # print({i}) 206 # print(f"ann: {x}") 207 # print(f"inp: {y}") 208 # print("\n") 209 # if count > 50: 210 # break 211 if annotation_type != "none" and ( 212 self.annotation_store.get_original_coordinates().shape 213 != self.input_store.get_original_coordinates().shape 214 or ( 215 self.annotation_store.get_original_coordinates() 216 != self.input_store.get_original_coordinates() 217 ).any() 218 ): 219 raise RuntimeError( 220 "The clip orders in the annotation store and input store are different!" 221 ) 222 223 def __getitem__(self, item: int) -> Dict: 224 idx = self._get_idx(item) 225 input = deepcopy(self.input_store[idx]) 226 target = self.annotation_store[idx] 227 tag = self.input_store.get_tag(idx) 228 ssl_inputs, ssl_targets = self._get_SSL_targets(input) 229 batch = {"input": input} 230 for name, x in zip( 231 ["target", "ssl_inputs", "ssl_targets", "tag"], 232 [target, ssl_inputs, ssl_targets, tag], 233 ): 234 if x is not None: 235 batch[name] = x 236 batch["index"] = idx 237 if self.stats is not None: 238 for key in batch["input"].keys(): 239 key_name = key.split("---")[0] 240 if key_name in self.stats: 241 batch["input"][key][:, batch["input"][key].sum(0) != 0] = ( 242 (batch["input"][key] - self.stats[key_name]["mean"]) 243 / (self.stats[key_name]["std"] + 1e-7) 244 )[:, batch["input"][key].sum(0) != 0] 245 return batch 246 247 def __len__(self) -> int: 248 return len(self.indices) 249 # if self.annotation_type != "none": 250 # return self.annotation_store.get_len(return_unlabeled=self.return_unlabeled) 251 # else: 252 # return len(self.input_store) 253 254 def get_tags(self) -> List: 255 """ 256 Get a list of all meta tags 257 258 Returns 259 ------- 260 tags: List 261 a list of unique meta tag values 262 """ 263 264 return self.input_store.get_tags() 265 266 def save(self, save_path: str) -> None: 267 """ 268 Save the dictionary 269 270 Parameters 271 ---------- 272 save_path : str 273 the path where the pickled file will be stored 274 """ 275 276 input_obj = self.input_store.key_objects() 277 annotation_obj = self.annotation_store.key_objects() 278 with open(save_path, "wb") as f: 279 pickle.dump((input_obj, annotation_obj), f) 280 281 def to_ram(self) -> None: 282 """ 283 Transfer the dataset to RAM 284 """ 285 286 self.input_store.to_ram() 287 self.annotation_store.to_ram() 288 289 def generate_full_length_gt(self) -> Dict: 290 if self.annotation_class() == 
"exclusive_classification": 291 gt = torch.zeros((len(self), self.len_segment())) 292 else: 293 gt = torch.zeros( 294 (len(self), len(self.behaviors_dict()), self.len_segment()) 295 ) 296 for i in range(len(self)): 297 gt[i] = self.annotation_store[i] 298 return self.generate_full_length_prediction(gt) 299 300 def generate_full_length_prediction(self, predicted: torch.Tensor) -> Dict: 301 """ 302 Map predictions for the equal-length pieces to predictions for the original data 303 304 Probabilities are averaged over predictions on overlapping intervals. 305 306 Parameters 307 ---------- 308 predicted: torch.Tensor 309 a tensor of predicted probabilities of shape `(N, #classes, #frames)` 310 311 Returns 312 ------- 313 full_length_prediction : dict 314 a nested dictionary where first-level keys are video ids, second-level keys are clip ids and values are 315 averaged probability tensors 316 """ 317 318 result = defaultdict(lambda: {}) 319 counter = defaultdict(lambda: {}) 320 coordinates = self.input_store.get_original_coordinates() 321 for coords, prediction in zip(coordinates, predicted): 322 l = self.input_store.get_clip_length_from_coords(coords) 323 video_name = self.input_store.get_video_id(coords) 324 clip_id = self.input_store.get_clip_id(coords) 325 start, end = self.input_store.get_clip_start_end(coords) 326 if clip_id not in result[video_name].keys(): 327 result[video_name][clip_id] = torch.zeros(*prediction.shape[:-1], l) 328 counter[video_name][clip_id] = torch.zeros(*prediction.shape[:-1], l) 329 result[video_name][clip_id][..., start:end] += ( 330 prediction.squeeze()[..., : end - start].detach().cpu() 331 ) 332 counter[video_name][clip_id][..., start:end] += 1 333 for video_name in result: 334 for clip_id in result[video_name]: 335 result[video_name][clip_id] /= counter[video_name][clip_id] 336 result[video_name][clip_id][counter[video_name][clip_id] == 0] = -100 337 result = dict(result) 338 return result 339 340 def find_valleys( 341 self, 342 predicted: Union[torch.Tensor, Dict], 343 threshold: float = 0.5, 344 min_frames: int = 0, 345 visibility_min_score: float = 0, 346 visibility_min_frac: float = 0, 347 main_class: int = 1, 348 low: bool = True, 349 predicted_error: torch.Tensor = None, 350 error_threshold: float = 0.5, 351 hysteresis: bool = False, 352 threshold_diff: float = None, 353 min_frames_error: int = None, 354 smooth_interval: int = 1, 355 cut_annotated: bool = False, 356 ) -> Dict: 357 """ 358 Find the intervals where the probability of a certain class is below or above a certain hard_threshold 359 360 Parameters 361 ---------- 362 predicted : torch.Tensor | dict 363 either a tensor of predictions for the data prompts or the output of 364 `BehaviorDataset.generate_full_length_prediction` 365 threshold : float, default 0.5 366 the main hard_threshold 367 min_frames : int, default 0 368 the minimum length of the intervals 369 visibility_min_score : float, default 0 370 the minimum visibility score in the intervals 371 visibility_min_frac : float, default 0 372 fraction of the interval that has to have the visibility score larger than visibility_score_thr 373 main_class : int, default 1 374 the index of the class the function is inspecting 375 low : bool, default True 376 if True, the probability in the intervals has to be below the hard_threshold, and if False, it has to be above 377 predicted_error : torch.Tensor, optional 378 a tensor of error predictions for the data prompts 379 error_threshold : float, default 0.5 380 maximum possible probability of error 
at the intervals 381 hysteresis: bool, default False 382 if True, the function will apply a hysteresis hard_threshold with the soft hard_threshold defined by threshold_diff 383 threshold_diff: float, optional 384 the difference between the soft and hard hard_threshold if hysteresis is used; if hysteresis is True, low is False and threshold_diff is None, the soft hard_threshold condition is set to the main_class having a larger probability than other classes 385 min_frames_error: int, optional 386 if not None, the intervals will only be considered where the error probability is below error_threshold at at least min_frames_error consecutive frames 387 388 Returns 389 ------- 390 valleys : dict 391 a dictionary where keys are video ids and values are lists of (start, end, individual name) tuples that denote the chosen intervals 392 """ 393 394 result = defaultdict(lambda: []) 395 if type(predicted) is not dict: 396 predicted = self.generate_full_length_prediction(predicted) 397 if predicted_error is not None: 398 predicted_error = self.generate_full_length_prediction(predicted_error) 399 elif min_frames_error is not None and min_frames_error != 0: 400 # warnings.warn( 401 # f"The min_frames_error parameter is set to {min_frames_error} but no error prediction " 402 # f"is given! Setting min_frames_error to 0." 403 # ) 404 min_frames_error = 0 405 if low and hysteresis and threshold_diff is None: 406 raise ValueError( 407 "Cannot set low=True, hysteresis=True and threshold_diff=None! Please set threshold_diff." 408 ) 409 if cut_annotated: 410 masked_intervals_dict = self.get_annotated_intervals() 411 else: 412 masked_intervals_dict = None 413 print("Valleys found:") 414 for v_id in predicted: 415 for clip_id in predicted[v_id].keys(): 416 if predicted_error is not None: 417 error_mask = predicted_error[v_id][clip_id][1, :] < error_threshold 418 if min_frames_error is not None: 419 output, indices, counts = torch.unique_consecutive( 420 error_mask, return_inverse=True, return_counts=True 421 ) 422 wrong_indices = torch.where( 423 output * (counts < min_frames_error) 424 )[0] 425 if len(wrong_indices) > 0: 426 for i in wrong_indices: 427 error_mask[indices == i] = False 428 else: 429 error_mask = None 430 if masked_intervals_dict is not None: 431 masked_intervals = masked_intervals_dict[v_id][clip_id] 432 else: 433 masked_intervals = None 434 if not hysteresis: 435 res_indices_start, res_indices_end = apply_threshold( 436 predicted[v_id][clip_id][main_class, :], 437 threshold, 438 low, 439 error_mask, 440 min_frames, 441 smooth_interval, 442 masked_intervals, 443 ) 444 elif threshold_diff is not None: 445 if low: 446 soft_threshold = threshold + threshold_diff 447 else: 448 soft_threshold = threshold - threshold_diff 449 res_indices_start, res_indices_end = apply_threshold_hysteresis( 450 predicted[v_id][clip_id][main_class, :], 451 soft_threshold, 452 threshold, 453 low, 454 error_mask, 455 min_frames, 456 smooth_interval, 457 masked_intervals, 458 ) 459 else: 460 res_indices_start, res_indices_end = apply_threshold_max( 461 predicted[v_id][clip_id], 462 threshold, 463 main_class, 464 error_mask, 465 min_frames, 466 smooth_interval, 467 masked_intervals, 468 ) 469 start = self.input_store.get_clip_start(v_id, clip_id) 470 result[v_id] += [ 471 [i + start, j + start, clip_id] 472 for i, j in zip(res_indices_start, res_indices_end) 473 if self.input_store.get_visibility( 474 v_id, clip_id, i, j, visibility_min_score 475 ) 476 > visibility_min_frac 477 ] 478 result[v_id] = sorted(result[v_id]) 479 
print(f" {v_id}: {len(result[v_id])}") 480 return dict(result) 481 482 def valleys_union(self, valleys_list) -> Dict: 483 """ 484 Find the intersection of two valleys dictionaries 485 486 Parameters 487 ---------- 488 valleys_list : list 489 a list of valleys dictionaries 490 491 Returns 492 ------- 493 intersection : dict 494 a new valleys dictionary with the intersection of the input intervals 495 """ 496 497 valleys_list = [x for x in valleys_list if x is not None] 498 if len(valleys_list) == 1: 499 return valleys_list[0] 500 elif len(valleys_list) == 0: 501 return {} 502 union = {} 503 keys_list = [set(valleys.keys()) for valleys in valleys_list] 504 keys = set.union(*keys_list) 505 for v_id in keys: 506 res = [] 507 clips_list = [ 508 set([x[-1] for x in valleys[v_id]]) for valleys in valleys_list 509 ] 510 clips = set.union(*clips_list) 511 for clip_id in clips: 512 clip_intervals = [ 513 x 514 for valleys in valleys_list 515 for x in valleys[v_id] 516 if x[-1] == clip_id 517 ] 518 v_len = self.input_store.get_clip_length(v_id, clip_id) 519 arr = torch.zeros(v_len) 520 for start, end, _ in clip_intervals: 521 arr[start:end] += 1 522 output, indices, counts = torch.unique_consecutive( 523 arr > 0, return_inverse=True, return_counts=True 524 ) 525 long_indices = torch.where(output)[0] 526 res += [ 527 ( 528 (indices == i).nonzero(as_tuple=True)[0][0].item(), 529 (indices == i).nonzero(as_tuple=True)[0][-1].item(), 530 clip_id, 531 ) 532 for i in long_indices 533 ] 534 union[v_id] = res 535 return union 536 537 def valleys_intersection(self, valleys_list) -> Dict: 538 """ 539 Find the intersection of two valleys dictionaries 540 541 Parameters 542 ---------- 543 valleys_list : list 544 a list of valleys dictionaries 545 546 Returns 547 ------- 548 intersection : dict 549 a new valleys dictionary with the intersection of the input intervals 550 """ 551 552 valleys_list = [x for x in valleys_list if x is not None] 553 if len(valleys_list) == 1: 554 return valleys_list[0] 555 elif len(valleys_list) == 0: 556 return {} 557 intersection = {} 558 keys_list = [set(valleys.keys()) for valleys in valleys_list] 559 keys = set.intersection(*keys_list) 560 for v_id in keys: 561 res = [] 562 clips_list = [ 563 set([x[-1] for x in valleys[v_id]]) for valleys in valleys_list 564 ] 565 clips = set.intersection(*clips_list) 566 for clip_id in clips: 567 clip_intervals = [ 568 x 569 for valleys in valleys_list 570 for x in valleys[v_id] 571 if x[-1] == clip_id 572 ] 573 v_len = self.input_store.get_clip_length(v_id, clip_id) 574 arr = torch.zeros(v_len) 575 for start, end, _ in clip_intervals: 576 arr[start:end] += 1 577 output, indices, counts = torch.unique_consecutive( 578 arr, return_inverse=True, return_counts=True 579 ) 580 long_indices = torch.where(output == 2)[0] 581 res += [ 582 ( 583 (indices == i).nonzero(as_tuple=True)[0][0].item(), 584 (indices == i).nonzero(as_tuple=True)[0][-1].item(), 585 clip_id, 586 ) 587 for i in long_indices 588 ] 589 intersection[v_id] = res 590 return intersection 591 592 def partition_train_test_val( 593 self, 594 use_test: float = 0, 595 split_path: str = None, 596 method: str = "random", 597 val_frac: float = 0, 598 test_frac: float = 0, 599 save_split: bool = False, 600 normalize: bool = False, 601 skip_normalization_keys: List = None, 602 stats: Dict = None, 603 ) -> Tuple: 604 """ 605 Partition the dataset into three new datasets 606 607 Parameters 608 ---------- 609 use_test : float, default 0 610 The fraction of the test dataset to be used in training 
without labels 611 split_path : str, optional 612 The path to load the split information from (if `'file'` method is used) and to save it to 613 (if `'save_split'` is `True`) 614 method : {'random', 'random:test-from-name', 'random:test-from-name:{name}', 615 'val-from-name:{val_name}:test-from-name:{test_name}', 616 'random:equalize:segments', 'random:equalize:videos', 617 'folders', 'time', 'time:strict', 'file'} 618 The partitioning method: 619 - `'random'`: sort videos into subsets randomly, 620 - `'random:test-from-name'` (or `'random:test-from-name:{name}'`): sort videos into training and validation 621 subsets randomly and create 622 the test subset from the video ids that start with a specific substring (`'test'` by default, or `name` 623 if provided), 624 - `'random:equalize:segments'` and `'random:equalize:videos'`: sort videos into subsets randomly but 625 making sure that for the rarest classes at least `0.8 * val_frac` of the videos/segments that contain 626 occurrences of the class get into the validation subset and `0.8 * test_frac` get into the test subset; 627 this is ensured for all classes in order of increasing number of occurrences until the validation and 628 test subsets are full 629 - `'val-from-name:{val_name}:test-from-name:{test_name}'`: create the validation and test 630 subsets from the video ids that start with specific substrings (`val_name` for validation 631 and `test_name` for test) and sort all other videos into the training subset 632 - `'folders'`: read videos from folders named *test*, *train* and *val* into corresponding subsets, 633 - `'time'`: split each video into training, validation and test subsequences, 634 - `'time:strict'`: split each video into validation, test and training subsequences 635 and throw out the last segments in validation and test (to get rid of overlaps), 636 - `'file'`: split according to a split file.
637 val_frac : float, default 0 638 The fraction of the dataset to be used in validation 639 test_frac : float, default 0 640 The fraction of the dataset to be used in test 641 save_split : bool, default False 642 Save a split file if True 643 644 Returns 645 ------- 646 train_dataset : BehaviorDataset 647 train dataset 648 649 val_dataset : BehaviorDataset 650 validation dataset 651 652 test_dataset : BehaviorDataset 653 test dataset 654 """ 655 656 train_indices, test_indices, val_indices = self._partition_indices( 657 split_path=split_path, 658 method=method, 659 val_frac=val_frac, 660 test_frac=test_frac, 661 save_split=save_split, 662 ) 663 val_dataset = self._create_new_dataset(val_indices) 664 test_dataset = self._create_new_dataset(test_indices) 665 train_dataset = self._create_new_dataset( 666 train_indices, ssl_indices=test_indices[: int(len(test_indices) * use_test)] 667 ) 668 print("Number of samples:") 669 print(f" validation:") 670 print(f" {val_dataset.count_classes()}") 671 print(f" training:") 672 print(f" {train_dataset.count_classes()}") 673 print(f" test:") 674 print(f" {test_dataset.count_classes()}") 675 if normalize: 676 if stats is None: 677 print("Computing normalization statistics...") 678 stats = train_dataset.get_normalization_stats(skip_normalization_keys) 679 else: 680 print("Setting loaded normalization statistics...") 681 train_dataset.set_normalization_stats(stats) 682 val_dataset.set_normalization_stats(stats) 683 test_dataset.set_normalization_stats(stats) 684 return train_dataset, test_dataset, val_dataset 685 686 def class_weights(self, proportional=False) -> List: 687 """ 688 Calculate class weights in inverse proportion to number of samples 689 Returns 690 ------- 691 weights: list 692 a list of class weights 693 """ 694 695 items = sorted( 696 [ 697 (k, v) 698 for k, v in self.annotation_store.count_classes().items() 699 if k != -100 700 ] 701 ) 702 if self.annotation_store.annotation_class() == "exclusive_classification": 703 if not proportional: 704 numerator = len(self.annotation_store) 705 else: 706 numerator = max([x[1] for x in items]) 707 weights = [numerator / (v + 1e-7) for _, v in items] 708 else: 709 items_zero = sorted( 710 [ 711 (k, v) 712 for k, v in self.annotation_store.count_classes(zeros=True).items() 713 if k != -100 714 ] 715 ) 716 if not proportional: 717 numerators = defaultdict(lambda: len(self.annotation_store)) 718 else: 719 numerators = { 720 item_one[0]: max(item_one[1], item_zero[1]) 721 for item_one, item_zero in zip(items, items_zero) 722 } 723 weights = {} 724 weights[0] = [numerators[k] / (v + 1e-7) for k, v in items_zero] 725 weights[1] = [numerators[k] / (v + 1e-7) for k, v in items] 726 return weights 727 728 def boundary_class_weight(self): 729 if self.annotation_type != "none": 730 f = self.annotation_store.data.flatten() 731 _, inv = torch.unique_consecutive(f, return_inverse=True) 732 boundary = torch.cat([torch.tensor([0]), torch.diff(inv)]).reshape( 733 self.annotation_store.data.shape 734 ) 735 boundary[..., 0] = 0 736 cnt = Counter(boundary.flatten().numpy()) 737 return cnt[1] / cnt[0] 738 else: 739 return 0 740 741 def count_classes(self, bouts: bool = False) -> Dict: 742 """ 743 Get a class counter dictionary 744 745 Parameters 746 ---------- 747 bouts : bool, default False 748 if `True`, instead of frame counts segment counts are returned 749 750 Returns 751 ------- 752 count_dictionary : dict 753 a dictionary with class indices as keys and frame or bout counts as values 754 """ 755 756 return 
self.annotation_store.count_classes(bouts=bouts) 757 758 def behaviors_dict(self) -> Dict: 759 """ 760 Get a behavior dictionary 761 762 Returns 763 ------- 764 dict 765 behavior dictionary 766 """ 767 768 return self.annotation_store.behaviors_dict() 769 770 def bodyparts_order(self) -> List: 771 try: 772 return self.input_store.get_bodyparts() 773 except: 774 raise RuntimeError( 775 f"The {self.input_type} input store does not have bodyparts implemented!" 776 ) 777 778 def features_shape(self) -> Dict: 779 """ 780 Get the shapes of the input features 781 782 Returns 783 ------- 784 shapes : Dict 785 a dictionary with the shapes of the features 786 """ 787 788 sample = self.input_store[0] 789 shapes = {k: v.shape for k, v in sample.items()} 790 # for key, value in shapes.items(): 791 # print(f'{key}: {value}') 792 return shapes 793 794 def num_classes(self) -> int: 795 """ 796 Get the number of classes in the data 797 798 Returns 799 ------- 800 num_classes : int 801 the number of classes 802 """ 803 804 return len(self.annotation_store.behaviors_dict()) 805 806 def len_segment(self) -> int: 807 """ 808 Get the segment length in the data 809 810 Returns 811 ------- 812 len_segment : int 813 the segment length 814 """ 815 816 sample = self.input_store[0] 817 key = list(sample.keys())[0] 818 return sample[key].shape[-1] 819 820 def set_ssl_transformations(self, ssl_transformations: List) -> None: 821 """ 822 Set new SSL transformations 823 824 Parameters 825 ---------- 826 ssl_transformations : list 827 a list of functions that take a sample feature dictionary as input and output ssl_inputs and ssl_targets 828 lists 829 """ 830 831 self.ssl_transformations = ssl_transformations 832 833 @classmethod 834 def new(cls, *args, **kwargs): 835 """ 836 Create a new object of the same class 837 838 Returns 839 ------- 840 new_instance: BehaviorDataset 841 a new instance of the same class 842 """ 843 844 return cls(*args, **kwargs) 845 846 @classmethod 847 def get_parameters(cls, data_type: str, annotation_type: str) -> List: 848 """ 849 Get parameters necessary for initialization 850 851 Parameters 852 ---------- 853 data_type : str 854 the data type 855 annotation_type : str 856 the annotation type 857 """ 858 859 input_features = options.input_stores[data_type].get_parameters() 860 annotation_features = options.annotation_stores[ 861 annotation_type 862 ].get_parameters() 863 self_features = inspect.getfullargspec(cls.__init__).args 864 return self_features + input_features + annotation_features 865 866 @staticmethod 867 def data_types() -> List: 868 """ 869 List available data types 870 871 Returns 872 ------- 873 data_types : list 874 available data types 875 """ 876 877 return list(options.input_stores.keys()) 878 879 @staticmethod 880 def annotation_types() -> List: 881 """ 882 List available annotation types 883 884 Returns 885 ------- 886 annotation_types : list 887 available annotation types 888 """ 889 890 return list(options.annotation_stores.keys()) 891 892 def _get_SSL_targets(self, input: Dict) -> Tuple[List, List]: 893 """ 894 Get the SSL inputs and targets from a sample dictionary 895 """ 896 897 ssl_inputs = [] 898 ssl_targets = [] 899 for transform in self.ssl_transformations: 900 ssl_input, ssl_target = transform(copy(input)) 901 ssl_inputs.append(ssl_input) 902 ssl_targets.append(ssl_target) 903 return ssl_inputs, ssl_targets 904 905 def _create_new_dataset(self, indices: List, ssl_indices: List = None): 906 """ 907 Create a subsample of the dataset, with samples at ssl_indices 
losing the annotation 908 """ 909 910 if ssl_indices is None: 911 ssl_indices = [] 912 input_store = self.input_store.create_subsample(indices, ssl_indices) 913 annotation_store = self.annotation_store.create_subsample(indices, ssl_indices) 914 new = self.new( 915 data_type=self.input_type, 916 annotation_type=self.annotation_type, 917 ssl_transformations=self.ssl_transformations, 918 annotation_store=annotation_store, 919 input_store=input_store, 920 ids=list(indices) + list(ssl_indices), 921 recompute_annotation=False, 922 ) 923 return new 924 925 def _load_input_store(self, data_type: str, key_objects: Tuple) -> InputStore: 926 """ 927 Load input store from key objects 928 """ 929 930 input_store = options.input_stores[data_type](key_objects=key_objects) 931 return input_store 932 933 def _load_annotation_store( 934 self, annotation_type: str, key_objects: Tuple 935 ) -> AnnotationStore: 936 """ 937 Load annotation store from key objects 938 """ 939 940 annotation_store = options.annotation_stores[annotation_type]( 941 key_objects=key_objects 942 ) 943 return annotation_store 944 945 def _get_input_store(self, data_type: str, data_parameters: Dict) -> InputStore: 946 """ 947 Create input store from parameters 948 """ 949 950 data_parameters["key_objects"] = None 951 input_store = options.input_stores[data_type](**data_parameters) 952 return input_store 953 954 def _get_annotation_store( 955 self, annotation_type: str, data_parameters: Dict 956 ) -> AnnotationStore: 957 """ 958 Create annotation store from parameters 959 """ 960 961 annotation_store = options.annotation_stores[annotation_type](**data_parameters) 962 return annotation_store 963 964 def set_indexing_parameters(self, unlabeled: bool, tag: int) -> None: 965 """ 966 Set the parameters that change the subset that is returned at `__getitem__` 967 968 Parameters 969 ---------- 970 unlabeled : bool 971 a pseudolabeling parameter; return only unlabeled samples if `True`, only labeled if `False` and 972 all if `None` 973 tag : int 974 if not `None`, only samples with this meta tag will be returned 975 """ 976 977 if unlabeled != self.return_unlabeled: 978 self.annotation_indices = self.annotation_store.get_indices(unlabeled) 979 self.return_unlabeled = unlabeled 980 if tag != self.tag: 981 self.input_indices = self.input_store.get_indices(tag) 982 self.tag = tag 983 self.indices = [x for x in self.annotation_indices if x in self.input_indices] 984 985 def _get_idx(self, index: int) -> int: 986 """ 987 Get index in full dataset 988 """ 989 990 return self.indices[index] 991 992 # return self.annotation_store.get_idx( 993 # index, return_unlabeled=self.return_unlabeled 994 # ) 995 996 def _partition_indices( 997 self, 998 split_path: str = None, 999 method: str = "random", 1000 val_frac: float = 0, 1001 test_frac: float = 0, 1002 save_split: bool = False, 1003 ) -> Tuple[List, List, List]: 1004 """ 1005 Partition indices into train, validation, test subsets 1006 """ 1007 1008 if self.mask is not None: 1009 val_indices = self.mask["val_ids"] 1010 train_indices = [x for x in range(len(self)) if x not in val_indices] 1011 test_indices = [] 1012 elif method == "random": 1013 videos = np.array(self.input_store.get_video_id_order()) 1014 all_videos = list(set(videos)) 1015 if len(all_videos) == 1: 1016 warnings.warn( 1017 "There is only one video in the dataset, so train/val/test split is done on segments; " 1018 'that might lead to overlaps, please consider using "time" or "time:strict" as the ' 1019 "partitioning method instead" 1020 ) 
1021 # Quick fix for single video: the problem with this is that the segments can overlap 1022 # length = int(self.input_store.get_original_coordinates()[-1][1]) # number of segments 1023 length = len(self.input_store.get_original_coordinates()) 1024 val_len = int(val_frac * length) 1025 test_len = int(test_frac * length) 1026 all_indices = np.random.choice(np.arange(length), length, replace=False) 1027 val_indices = all_indices[:val_len] 1028 test_indices = all_indices[val_len : val_len + test_len] 1029 train_indices = all_indices[val_len + test_len :] 1030 coords = self.input_store.get_original_coordinates() 1031 if save_split: 1032 self._save_partition( 1033 coords[train_indices], 1034 coords[val_indices], 1035 coords[test_indices], 1036 split_path, 1037 coords=True, 1038 ) 1039 else: 1040 length = len(all_videos) 1041 val_len = int(val_frac * length) 1042 test_len = int(test_frac * length) 1043 validation = all_videos[:val_len] 1044 test = all_videos[val_len : val_len + test_len] 1045 training = all_videos[val_len + test_len :] 1046 train_indices = np.where(np.isin(videos, training))[0] 1047 val_indices = np.where(np.isin(videos, validation))[0] 1048 test_indices = np.where(np.isin(videos, test))[0] 1049 if save_split: 1050 self._save_partition(training, validation, test, split_path) 1051 elif method.startswith("random:equalize"): 1052 counter = self.count_classes() 1053 counter = sorted(list([(v, k) for k, v in counter.items()])) 1054 classes = [x[1] for x in counter] 1055 indicator = {c: [] for c in classes} 1056 if method.endswith("videos"): 1057 videos = np.array(self.input_store.get_video_id_order()) 1058 all_videos = list(set(videos)) 1059 total_len = len(all_videos) 1060 for video_id in all_videos: 1061 video_coords = np.where(videos == video_id)[0] 1062 ann = torch.cat( 1063 [self.annotation_store[i] for i in video_coords], dim=-1 1064 ) 1065 for c in classes: 1066 if self.annotation_class() == "nonexclusive_classification": 1067 indicator[c].append(torch.sum(ann[c] == 1) > 0) 1068 elif self.annotation_class() == "exclusive_classification": 1069 indicator[c].append(torch.sum(ann == c) > 0) 1070 else: 1071 raise ValueError( 1072 f"The random:equalize partition method is not implemented" 1073 f"for the {self.annotation_class()} annotation class" 1074 ) 1075 elif method.endswith("segments"): 1076 total_len = len(self) 1077 for ann in self.annotation_store: 1078 for c in classes: 1079 if self.annotation_class() == "nonexclusive_classification": 1080 indicator[c].append(torch.sum(ann[c] == 1) > 0) 1081 elif self.annotation_class() == "exclusive_classification": 1082 indicator[c].append(torch.sum(ann == c) > 0) 1083 else: 1084 raise ValueError( 1085 f"The random:equalize partition method is not implemented" 1086 f"for the {self.annotation_class()} annotation class" 1087 ) 1088 else: 1089 values = [] 1090 for v in options.partition_methods.values(): 1091 values += v 1092 raise ValueError( 1093 f"The {method} partition method is not recognized; please choose from {values}" 1094 ) 1095 val_indices = [] 1096 test_indices = [] 1097 for c in classes: 1098 indicator[c] = np.array(indicator[c]) 1099 ind = np.where(indicator[c])[0] 1100 np.random.shuffle(ind) 1101 c_sum = len(ind) 1102 in_val = np.sum(indicator[c][val_indices]) 1103 in_test = np.sum(indicator[c][test_indices]) 1104 while ( 1105 len(val_indices) < val_frac * total_len 1106 and in_val < val_frac * c_sum * 0.8 1107 ): 1108 first, ind = ind[0], ind[1:] 1109 val_indices = list(set(val_indices).union({first})) 1110 in_val = 
np.sum(indicator[c][val_indices]) 1111 while ( 1112 len(test_indices) < test_frac * total_len 1113 and in_test < test_frac * c_sum * 0.8 1114 ): 1115 first, ind = ind[0], ind[1:] 1116 test_indices = list(set(test_indices).union({first})) 1117 in_test = np.sum(indicator[c][test_indices]) 1118 if len(val_indices) < int(val_frac * total_len): 1119 left_val = int(val_frac * total_len) - len(val_indices) 1120 else: 1121 left_val = 0 1122 if len(test_indices) < int(test_frac * total_len): 1123 left_test = int(test_frac * total_len) - len(test_indices) 1124 else: 1125 left_test = 0 1126 indicator = np.ones(total_len) 1127 indicator[val_indices] = 0 1128 indicator[test_indices] = 0 1129 ind = np.where(indicator)[0] 1130 np.random.shuffle(ind) 1131 val_indices += list(ind[:left_val]) 1132 test_indices += list(ind[left_val : left_val + left_test]) 1133 train_indices = list(ind[left_val + left_test :]) 1134 if save_split: 1135 if method.endswith("segments"): 1136 coords = self.input_store.get_original_coordinates() 1137 self._save_partition( 1138 coords[train_indices], 1139 coords[val_indices], 1140 coords[test_indices], 1141 coords[split_path], 1142 coords=True, 1143 ) 1144 else: 1145 all_videos = np.array(all_videos) 1146 validation = all_videos[val_indices] 1147 test = all_videos[test_indices] 1148 training = all_videos[train_indices] 1149 self._save_partition(training, validation, test, split_path) 1150 elif method.startswith("random:test-from-name"): 1151 split = method.split(":") 1152 if len(split) > 2: 1153 test_name = split[-1] 1154 else: 1155 test_name = "test" 1156 videos = np.array(self.input_store.get_video_id_order()) 1157 all_videos = list(set(videos)) 1158 test = [] 1159 train_videos = [] 1160 for x in all_videos: 1161 if x.startswith(test_name): 1162 test.append(x) 1163 else: 1164 train_videos.append(x) 1165 length = len(train_videos) 1166 val_len = int(val_frac * length) 1167 validation = train_videos[:val_len] 1168 training = train_videos[val_len:] 1169 train_indices = np.where(np.isin(videos, training))[0] 1170 val_indices = np.where(np.isin(videos, validation))[0] 1171 test_indices = np.where(np.isin(videos, test))[0] 1172 if save_split: 1173 self._save_partition(training, validation, test, split_path) 1174 elif method.startswith("val-from-name"): 1175 split = method.split(":") 1176 if split[2] != "test-from-name": 1177 raise ValueError( 1178 f"The {method} partition method is not recognized, please choose from {options.partition_methods}" 1179 ) 1180 val_name = split[1] 1181 test_name = split[-1] 1182 videos = np.array(self.input_store.get_video_id_order()) 1183 all_videos = list(set(videos)) 1184 test = [] 1185 validation = [] 1186 training = [] 1187 for x in all_videos: 1188 if x.startswith(test_name): 1189 test.append(x) 1190 elif x.startswith(val_name): 1191 validation.append(x) 1192 else: 1193 training.append(x) 1194 train_indices = np.where(np.isin(videos, training))[0] 1195 val_indices = np.where(np.isin(videos, validation))[0] 1196 test_indices = np.where(np.isin(videos, test))[0] 1197 elif method == "folders": 1198 folders = np.array(self.input_store.get_folder_order()) 1199 videos = np.array(self.input_store.get_video_id_order()) 1200 train_indices = np.where(np.isin(folders, ["training", "train"]))[0] 1201 if np.sum(np.isin(folders, ["validation", "val"])) > 0: 1202 val_indices = np.where(np.isin(folders, ["validation", "val"]))[0] 1203 else: 1204 train_videos = list(set(videos[train_indices])) 1205 val_len = int(val_frac * len(train_videos)) 1206 validation = 
train_videos[:val_len] 1207 training = train_videos[val_len:] 1208 train_indices = np.where(np.isin(videos, training))[0] 1209 val_indices = np.where(np.isin(videos, validation))[0] 1210 test_indices = np.where(folders == "test")[0] 1211 if save_split: 1212 self._save_partition( 1213 list(set(videos[train_indices])), 1214 list(set(videos[val_indices])), 1215 list(set(videos[test_indices])), 1216 split_path, 1217 ) 1218 elif method.startswith("leave-one-out"): 1219 n = int(method.split(":")[-1]) 1220 videos = np.array(self.input_store.get_video_id_order()) 1221 all_videos = sorted(list(set(videos))) 1222 validation = [all_videos.pop(n)] 1223 training = all_videos 1224 train_indices = np.where(np.isin(videos, training))[0] 1225 val_indices = np.where(np.isin(videos, validation))[0] 1226 test_indices = np.array([]) 1227 elif method.startswith("time"): 1228 if method.endswith("strict"): 1229 len_segment = self.len_segment() 1230 # TODO: change this 1231 step = self.input_store.step 1232 num_removed = len_segment // step 1233 else: 1234 num_removed = 0 1235 videos = np.array(self.input_store.get_video_id_order()) 1236 all_videos = set(videos) 1237 train_indices = [] 1238 val_indices = [] 1239 test_indices = [] 1240 start = 0 1241 if len(method.split(":")) > 1 and method.split(":")[1] == "start-from": 1242 start = float(method.split(":")[2]) 1243 for video_id in all_videos: 1244 video_indices = np.where(videos == video_id)[0] 1245 val_len = int(val_frac * len(video_indices)) 1246 test_len = int(test_frac * len(video_indices)) 1247 start_pos = int(start * len(video_indices)) 1248 all_ind = np.ones(len(video_indices)) 1249 val_indices += list(video_indices[start_pos : start_pos + val_len]) 1250 all_ind[start_pos : start_pos + val_len] = 0 1251 if start_pos + val_len > len(video_indices): 1252 p = start_pos + val_len - len(video_indices) 1253 val_indices += list(video_indices[:p]) 1254 all_ind[:p] = 0 1255 else: 1256 p = start_pos + val_len 1257 test_indices += list(video_indices[p : p + test_len]) 1258 all_ind[p : p + test_len] = 0 1259 if p + test_len > len(video_indices): 1260 p = test_len + p - len(video_indices) 1261 test_indices += list(video_indices[:p]) 1262 all_ind[:p] = 0 1263 train_indices += list(video_indices[all_ind > 0]) 1264 for _ in range(num_removed): 1265 if len(val_indices) > 0: 1266 val_indices.pop(-1) 1267 if len(test_indices) > 0: 1268 test_indices.pop(-1) 1269 if start > 0 and len(train_indices) > 0: 1270 train_indices.pop(-1) 1271 elif method == "file": 1272 if split_path is None: 1273 raise ValueError( 1274 'You need to either set split_path or change partition method ("file" requires a file)' 1275 ) 1276 with open(split_path) as f: 1277 train_line = f.readline() 1278 line = f.readline() 1279 while not line.startswith("Validation") and not line.startswith( 1280 "Validataion" 1281 ): 1282 line = f.readline() 1283 if line.startswith("Validation"): 1284 validation = [] 1285 test = [] 1286 while True: 1287 line = f.readline() 1288 if line.startswith("Test") or len(line) == 0: 1289 break 1290 validation.append(line.rstrip()) 1291 while True: 1292 line = f.readline() 1293 if len(line) == 0: 1294 break 1295 test.append(line.rstrip()) 1296 type = train_line[9:-2] 1297 else: 1298 line = f.readline() 1299 validation = line.rstrip(",\n ").split(", ") 1300 test = [] 1301 type = "videos" 1302 if type == "videos": 1303 videos = np.array(self.input_store.get_video_id_order()) 1304 val_indices = np.where(np.isin(videos, validation))[0] 1305 test_indices = np.where(np.isin(videos, 
test))[0] 1306 elif type == "coords": 1307 coords = self.input_store.get_original_coordinates() 1308 video_ids = self.input_store.get_video_id_order() 1309 clip_ids = [self.input_store.get_clip_id(coord) for coord in coords] 1310 starts, ends = zip( 1311 *[self.input_store.get_clip_start_end(coord) for coord in coords] 1312 ) 1313 coords = np.array( 1314 [ 1315 f"{video_id}---{clip_id}---{start}-{end}" 1316 for video_id, clip_id, start, end in zip( 1317 video_ids, clip_ids, starts, ends 1318 ) 1319 ] 1320 ) 1321 val_indices = np.where(np.isin(coords, validation))[0] 1322 test_indices = np.where(np.isin(coords, test))[0] 1323 else: 1324 raise ValueError("The split path has unrecognized format!") 1325 all = np.ones(len(self)) 1326 all[val_indices] = 0 1327 all[test_indices] = 0 1328 train_indices = np.where(all)[0] 1329 else: 1330 raise ValueError( 1331 f"The {method} partition is not recognized, please choose from {options.partition_methods}" 1332 ) 1333 return sorted(train_indices), sorted(test_indices), sorted(val_indices) 1334 1335 def _save_partition( 1336 self, 1337 training: List, 1338 validation: List, 1339 test: List, 1340 split_path: str, 1341 coords: bool = False, 1342 ) -> None: 1343 """ 1344 Save a split file 1345 """ 1346 1347 if coords: 1348 name = "coords" 1349 training_coords = [] 1350 val_coords = [] 1351 test_coords = [] 1352 for coord in training: 1353 video_id = self.input_store.get_video_id(coord) 1354 clip_id = self.input_store.get_clip_id(coord) 1355 start, end = self.input_store.get_clip_start_end(coord) 1356 training_coords.append(f"{video_id}---{clip_id}---{start}-{end}") 1357 for coord in validation: 1358 video_id = self.input_store.get_video_id(coord) 1359 clip_id = self.input_store.get_clip_id(coord) 1360 start, end = self.input_store.get_clip_start_end(coord) 1361 val_coords.append(f"{video_id}---{clip_id}---{start}-{end}") 1362 for coord in test: 1363 video_id = self.input_store.get_video_id(coord) 1364 clip_id = self.input_store.get_clip_id(coord) 1365 start, end = self.input_store.get_clip_start_end(coord) 1366 test_coords.append(f"{video_id}---{clip_id}---{start}-{end}") 1367 training, validation, test = training_coords, val_coords, test_coords 1368 else: 1369 name = "videos" 1370 if split_path is not None: 1371 with open(split_path, "w") as f: 1372 f.write(f"Training {name}:\n") 1373 for x in training: 1374 f.write(x + "\n") 1375 f.write(f"Validation {name}:\n") 1376 for x in validation: 1377 f.write(x + "\n") 1378 f.write(f"Test {name}:\n") 1379 for x in test: 1380 f.write(x + "\n") 1381 1382 def _get_intervals_from_ind(self, frame_indices: np.ndarray): 1383 """ 1384 Get a list of intervals from a list of frame indices 1385 1386 Example: `[0, 1, 2, 5, 6, 8] -> [[0, 3], [5, 7], [8, 9]]`. 
1387 1388 Parameters 1389 ---------- 1390 frame_indices : np.ndarray 1391 a list of frame indices 1392 1393 Returns 1394 ------- 1395 intervals : list 1396 a list of interval boundaries 1397 """ 1398 1399 masked_intervals = [] 1400 breaks = np.where(np.diff(frame_indices) != 1)[0] 1401 if len(frame_indices) > 0: 1402 start = frame_indices[0] 1403 for k in breaks: 1404 masked_intervals.append([start, frame_indices[k] + 1]) 1405 start = frame_indices[k + 1] 1406 masked_intervals.append([start, frame_indices[-1] + 1]) 1407 return masked_intervals 1408 1409 def get_intervals(self) -> Tuple[dict, Optional[list]]: 1410 """ 1411 Get a list of intervals covered by the dataset in the original coordinates 1412 1413 Returns 1414 ------- 1415 intervals : dict 1416 a nested dictionary where first-level keys are video ids, second-level keys are clip ids and 1417 values are lists of the intervals in `[start, end]` format 1418 """ 1419 1420 counter = defaultdict(lambda: {}) 1421 coordinates = self.input_store.get_original_coordinates() 1422 for coords in coordinates: 1423 l = self.input_store.get_clip_length_from_coords(coords) 1424 video_name = self.input_store.get_video_id(coords) 1425 clip_id = self.input_store.get_clip_id(coords) 1426 start, end = self.input_store.get_clip_start_end(coords) 1427 if clip_id not in counter[video_name]: 1428 counter[video_name][clip_id] = np.zeros(l) 1429 counter[video_name][clip_id][start:end] = 1 1430 result = {video_name: {} for video_name in counter} 1431 for video_name in counter: 1432 for clip_id in counter[video_name]: 1433 result[video_name][clip_id] = self._get_intervals_from_ind( 1434 np.where(counter[video_name][clip_id])[0] 1435 ) 1436 return result, self.ids 1437 1438 def get_unannotated_intervals(self, first_intervals=None) -> Dict: 1439 """ 1440 Get a list of intervals in the original coordinates where there is no annotation 1441 1442 Returns 1443 ------- 1444 intervals : dict 1445 a nested dictionary where first-level keys are video ids, second-level keys are clip ids and 1446 values are lists of the intervals in `[start, end]` format 1447 """ 1448 1449 counter_value = 2 1450 if first_intervals is None: 1451 first_intervals = defaultdict(lambda: defaultdict(lambda: [])) 1452 counter_value = 1 1453 counter = defaultdict(lambda: {}) 1454 coordinates = self.input_store.get_original_coordinates() 1455 for i, coords in enumerate(coordinates): 1456 l = self.input_store.get_clip_length_from_coords(coords) 1457 ann = self.annotation_store[i] 1458 if ( 1459 self.annotation_store.annotation_class() 1460 == "nonexclusive_classification" 1461 ): 1462 ann = ann[0, :] 1463 video_name = self.input_store.get_video_id(coords) 1464 clip_id = self.input_store.get_clip_id(coords) 1465 start, end = self.input_store.get_clip_start_end(coords) 1466 if clip_id not in counter[video_name]: 1467 counter[video_name][clip_id] = np.ones(l) 1468 counter[video_name][clip_id][start:end] = (ann[: end - start] == -100).int() 1469 result = {video_name: {} for video_name in counter} 1470 for video_name in counter: 1471 for clip_id in counter[video_name]: 1472 for start, end in first_intervals[video_name][clip_id]: 1473 counter[video_name][clip_id][start:end] += 1 1474 result[video_name][clip_id] = self._get_intervals_from_ind( 1475 np.where(counter[video_name][clip_id] == counter_value)[0] 1476 ) 1477 return result 1478 1479 def get_annotated_intervals(self) -> Dict: 1480 """ 1481 Get a list of intervals in the original coordinates where there is annotation 1482 1483 Returns 1484 -------
1485 intervals : dict 1486 a nested dictionary where first-level keys are video ids, second-level keys are clip ids and 1487 values are lists of the intervals in `[start, end]` format 1488 """ 1489 1490 if self.annotation_type == "none": 1491 return [] 1492 counter_value = 1 1493 counter = defaultdict(lambda: {}) 1494 coordinates = self.input_store.get_original_coordinates() 1495 for i, coords in enumerate(coordinates): 1496 l = self.input_store.get_clip_length_from_coords(coords) 1497 ann = self.annotation_store[i] 1498 video_name = self.input_store.get_video_id(coords) 1499 clip_id = self.input_store.get_clip_id(coords) 1500 start, end = self.input_store.get_clip_start_end(coords) 1501 if clip_id not in counter[video_name]: 1502 counter[video_name][clip_id] = np.zeros(l) 1503 if ( 1504 self.annotation_store.annotation_class() 1505 == "nonexclusive_classification" 1506 ): 1507 counter[video_name][clip_id][start:end] = ( 1508 torch.sum(ann[:, : end - start] != -100, dim=0) > 0 1509 ).int() 1510 else: 1511 counter[video_name][clip_id][start:end] = ( 1512 ann[: end - start] != -100 1513 ).int() 1514 result = {video_name: {} for video_name in counter} 1515 for video_name in counter: 1516 for clip_id in counter[video_name]: 1517 result[video_name][clip_id] = self._get_intervals_from_ind( 1518 np.where(counter[video_name][clip_id] == counter_value)[0] 1519 ) 1520 return result 1521 1522 def get_ids(self) -> Dict: 1523 """ 1524 Get a dictionary of all clip ids in the dataset 1525 1526 Returns 1527 ------- 1528 ids : dict 1529 a dictionary where keys are video ids and values are lists of clip ids 1530 """ 1531 1532 coordinates = self.input_store.get_original_coordinates() 1533 video_ids = np.array(self.input_store.get_video_id_order()) 1534 id_set = set(video_ids) 1535 result = {} 1536 for video_id in id_set: 1537 coords = coordinates[video_ids == video_id] 1538 clip_ids = list({self.input_store.get_clip_id(c) for c in coords}) 1539 result[video_id] = clip_ids 1540 return result 1541 1542 def get_len(self, video_id: str, clip_id: str) -> int: 1543 """ 1544 Get the length of a specific clip 1545 1546 Parameters 1547 ---------- 1548 video_id : str 1549 the video id 1550 clip_id : str 1551 the clip id 1552 1553 Returns 1554 ------- 1555 length : int 1556 the length 1557 """ 1558 1559 return self.input_store.get_clip_length(video_id, clip_id) 1560 1561 def get_confusion_matrix( 1562 self, prediction: torch.Tensor, confusion_type: str = "recall" 1563 ) -> Tuple[ndarray, list]: 1564 """ 1565 Get a confusion matrix 1566 1567 Parameters 1568 ---------- 1569 prediction : torch.Tensor 1570 a tensor of predicted class probabilities of shape `(#samples, #classes, #frames)` 1571 confusion_type : {"recall", "precision"} 1572 for datasets with non-exclusive annotation, if `type` is `"recall"`, only false positives are taken 1573 into account, and if `type` is `"precision"`, only false negatives 1574 1575 Returns 1576 ------- 1577 confusion_matrix : np.ndarray 1578 a confusion matrix of shape `(#classes, #classes)` where `A[i, j] = F_ij/N_i`, `F_ij` is the number of 1579 frames that have the i-th label in the ground truth and a false positive j-th label in the prediction, 1580 `N_i` is the number of frames that have the i-th label in the ground truth 1581 classes : list 1582 a list of classes 1583 """ 1584 1585 behaviors_dict = self.annotation_store.behaviors_dict() 1586 num_behaviors = len(behaviors_dict) 1587 confusion_matrix = np.zeros((num_behaviors, num_behaviors)) 1588 if 
self.annotation_store.annotation_class() == "exclusive_classification": 1589 exclusive = True 1590 confusion_type = None 1591 elif self.annotation_store.annotation_class() == "nonexclusive_classification": 1592 exclusive = False 1593 else: 1594 raise RuntimeError( 1595 f"The {self.annotation_store.annotation_class()} annotation class is not recognized!" 1596 ) 1597 for ann, p in zip(self.annotation_store, prediction): 1598 if exclusive: 1599 class_prediction = torch.max(p, dim=0)[1] 1600 for i in behaviors_dict.keys(): 1601 for j in behaviors_dict.keys(): 1602 confusion_matrix[i, j] += int( 1603 torch.sum(class_prediction[ann == i] == j) 1604 ) 1605 else: 1606 class_prediction = (p > 0.5).int() 1607 for i in behaviors_dict.keys(): 1608 for j in behaviors_dict.keys(): 1609 if confusion_type == "recall": 1610 pred = deepcopy(class_prediction[j]) 1611 if i != j: 1612 pred[ann[j] == 1] = 0 1613 confusion_matrix[i, j] += int(torch.sum(pred[ann[i] == 1])) 1614 elif confusion_type == "precision": 1615 annotation = deepcopy(ann[j]) 1616 if i != j: 1617 annotation[class_prediction[j] == 1] = 0 1618 confusion_matrix[i, j] += int( 1619 torch.sum(annotation[class_prediction[i] == 1]) 1620 ) 1621 else: 1622 raise ValueError( 1623 f"The {confusion_type} type is not recognized; please choose from ['recall', 'precision']" 1624 ) 1625 counter = self.annotation_store.count_classes() 1626 for i in behaviors_dict.keys(): 1627 if counter[i] != 0: 1628 if confusion_type == "recall" or confusion_type is None: 1629 confusion_matrix[i, :] /= counter[i] 1630 else: 1631 confusion_matrix[:, i] /= counter[i] 1632 return confusion_matrix, list(behaviors_dict.values()), confusion_type 1633 1634 def annotation_class(self) -> str: 1635 """ 1636 Get the type of annotation ('exclusive_classification', 'nonexclusive_classification', more coming soon) 1637 1638 Returns 1639 ------- 1640 annotation_class : str 1641 the type of annotation 1642 """ 1643 1644 return self.annotation_store.annotation_class() 1645 1646 def set_normalization_stats(self, stats: Dict) -> None: 1647 """ 1648 Set the stats to normalize data at runtime 1649 1650 Parameters 1651 ---------- 1652 stats : dict 1653 a nested dictionary where first-level keys are feature key names, second-level keys are 'mean' and 'std' 1654 and values are the statistics in `torch` tensors of shape `(#features, 1)` 1655 """ 1656 1657 self.stats = stats 1658 1659 def get_min_max_frames(self, video_id) -> Tuple[Dict, Dict]: 1660 coords = self.input_store.get_original_coordinates() 1661 clips = set( 1662 [ 1663 self.input_store.get_clip_id(c) 1664 for c in coords 1665 if self.input_store.get_video_id(c) == video_id 1666 ] 1667 ) 1668 min_frames = {} 1669 max_frames = {} 1670 for clip in clips: 1671 start = self.input_store.get_clip_start(video_id, clip) 1672 end = start + self.input_store.get_clip_length(video_id, clip) 1673 min_frames[clip] = start 1674 max_frames[clip] = end - 1 1675 return min_frames, max_frames 1676 1677 def get_normalization_stats(self, skip_keys=None) -> Dict: 1678 """ 1679 Get mean and standard deviation for each key 1680 1681 Returns 1682 ------- 1683 stats : dict 1684 a nested dictionary where first-level keys are feature key names, second-level keys are 'mean' and 'std' 1685 and values are the statistics in `torch` tensors of shape `(#features, 1)` 1686 """ 1687 1688 stats = defaultdict(lambda: {}) 1689 sums = defaultdict(lambda: 0) 1690 if skip_keys is None: 1691 skip_keys = [] 1692 counter = defaultdict(lambda: 0) 1693 for sample in tqdm(self): 1694 
for key, value in sample["input"].items(): 1695 key_name = key.split("---")[0] 1696 if key_name not in skip_keys: 1697 sums[key_name] += value[:, value.sum(0) != 0].sum(-1) 1698 counter[key_name] += torch.sum(value.sum(0) != 0) 1699 for key, value in sums.items(): 1700 stats[key]["mean"] = (value / counter[key]).unsqueeze(-1) 1701 sums = defaultdict(lambda: 0) 1702 for sample in tqdm(self): 1703 for key, value in sample["input"].items(): 1704 key_name = key.split("---")[0] 1705 if key_name not in skip_keys: 1706 sums[key_name] += ( 1707 (value[:, value.sum(0) != 0] - stats[key_name]["mean"]) ** 2 1708 ).sum(-1) 1709 for key, value in sums.items(): 1710 stats[key]["std"] = np.sqrt(value.unsqueeze(-1) / counter[key]) 1711 return stats
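For reference, the overlap-averaging that generate_full_length_prediction performs can be summarized with the standalone sketch below (an illustrative re-implementation of the logic in the source above, not the class's own code; it handles a single clip and assumes per-segment (start, end) frame coordinates).

import torch

def average_overlapping_segments(segments, coords, clip_length, num_classes):
    # segments: list of (num_classes, len_segment) probability tensors
    # coords: list of (start, end) frame indices of each segment within the clip
    summed = torch.zeros(num_classes, clip_length)
    counts = torch.zeros(num_classes, clip_length)
    for prediction, (start, end) in zip(segments, coords):
        summed[:, start:end] += prediction[:, : end - start]
        counts[:, start:end] += 1
    averaged = summed / counts.clamp(min=1)   # average where segments overlap
    averaged[counts == 0] = -100              # mark uncovered frames, as the class does
    return averaged

# two overlapping 128-frame segments in a 200-frame clip with 3 classes
segments = [torch.rand(3, 128), torch.rand(3, 128)]
full = average_overlapping_segments(segments, [(0, 128), (64, 192)], clip_length=200, num_classes=3)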
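Each sample is a dictionary, so the dataset plugs directly into a `torch.utils.data.DataLoader`. A sketch, assuming `dataset` from the first snippet; which optional keys (`'target'`, `'tag'`, `'ssl_inputs'`, `'ssl_targets'`) are present depends on the annotation type and the SSL transformations.

from torch.utils.data import DataLoader

sample = dataset[0]
print(sample.keys())  # always contains 'input' and 'index'
print({k: v.shape for k, v in sample["input"].items()})  # feature key -> (#features, len_segment)

loader = DataLoader(dataset, batch_size=8, shuffle=True)
batch = next(iter(loader))
inputs = batch["input"]   # dict of tensors, each of shape (batch, #features, len_segment)
target = batch["target"]  # annotation for each segment (present when the dataset is annotated)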
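`generate_full_length_prediction` maps per-segment model outputs back onto the original clips, averaging over overlapping segments. In the sketch below the random tensor is a stand-in for real model probabilities; its shape follows the documented `(#samples, #classes, #frames)` convention (`dataset` comes from the first snippet).

import torch

predicted = torch.rand(len(dataset), dataset.num_classes(), dataset.len_segment())

full_length = dataset.generate_full_length_prediction(predicted)
for video_id, clips in full_length.items():
    for clip_id, probs in clips.items():
        # probs has shape (#classes, clip length); frames not covered by any
        # segment are filled with -100
        print(video_id, clip_id, probs.shape)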
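`find_valleys` turns probabilities into intervals, for example to flag low-confidence stretches for re-annotation. A sketch with illustrative settings; `dataset` and `predicted` come from the previous snippets, and the class index 1 is hypothetical.

valleys = dataset.find_valleys(
    predicted,           # a tensor or the output of generate_full_length_prediction
    threshold=0.4,       # hard threshold on the probability of main_class
    min_frames=30,       # drop intervals shorter than 30 frames
    main_class=1,        # index of the class to inspect
    low=True,            # keep intervals where the probability is below the threshold
    cut_annotated=True,  # exclude stretches that are already annotated
)
for video_id, intervals in valleys.items():
    print(video_id, intervals[:3])  # [(start, end, clip_id), ...]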
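Interval dictionaries from several runs can be combined: `valleys_union` keeps frames that appear in any of the inputs, while `valleys_intersection` keeps frames covered by both of two dictionaries. A sketch reusing the objects from the previous snippets.

valleys_low = dataset.find_valleys(predicted, threshold=0.4, main_class=1, low=True)
valleys_strict = dataset.find_valleys(predicted, threshold=0.6, main_class=1, low=True)

merged = dataset.valleys_union([valleys_low, valleys_strict])         # frames in either result
common = dataset.valleys_intersection([valleys_low, valleys_strict])  # frames in both results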
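A partitioning sketch. Note that the subsets come back in `(train, test, val)` order, as in the `return` statement of the source above; with `normalize=True`, feature statistics are computed on the training subset and applied to all three. The split file path is a hypothetical example.

train_ds, test_ds, val_ds = dataset.partition_train_test_val(
    method="random",         # see the docstring above for the full list of methods
    val_frac=0.2,
    test_frac=0.1,
    normalize=True,
    save_split=True,
    split_path="split.txt",  # hypothetical path for the saved split file
)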
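The class weights can be plugged into a loss function. For exclusive classification `class_weights` returns a plain list ordered by class index; assuming the behavior indices are contiguous and start at 0, it maps onto `torch.nn.CrossEntropyLoss` as sketched below. For non-exclusive annotation the method returns a `{0: [...], 1: [...]}` dictionary of negative/positive weights instead.

import torch

weights = dataset.class_weights()  # inverse-frequency weight per class
if dataset.annotation_class() == "exclusive_classification":
    criterion = torch.nn.CrossEntropyLoss(
        weight=torch.tensor(weights, dtype=torch.float),
        ignore_index=-100,  # frames without annotation are marked with -100
    )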
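The clip bookkeeping helpers are useful when exporting intervals back to the original videos (still assuming `dataset` from the first snippet).

ids = dataset.get_ids()  # {video_id: [clip_id, ...]}
for video_id, clip_ids in ids.items():
    for clip_id in clip_ids:
        print(video_id, clip_id, dataset.get_len(video_id, clip_id))

annotated = dataset.get_annotated_intervals()      # {video_id: {clip_id: [[start, end], ...]}}
unannotated = dataset.get_unannotated_intervals()  # same structure, for unlabeled stretches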
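An evaluation sketch with `get_confusion_matrix`. Despite the `Tuple[ndarray, list]` annotation, the method returns three values; for exclusive classification the `confusion_type` argument is ignored and returned as `None`.

confusion, class_names, confusion_type = dataset.get_confusion_matrix(
    predicted,                # (#samples, #classes, #frames) probabilities
    confusion_type="recall",  # or "precision" for non-exclusive annotation
)
print(class_names)         # behavior names in matrix order
print(confusion.round(2))  # normalized by class frame counts (rows for 'recall', columns for 'precision')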
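Normalization statistics and the pickled key objects can be reused across runs. A sketch with hypothetical paths and a hypothetical feature key to skip.

from dlc2action.data.dataset import BehaviorDataset

# compute per-feature mean/std over the dataset and apply them at sampling time
stats = dataset.get_normalization_stats(skip_keys=["speed"])  # 'speed' is a hypothetical feature key
dataset.set_normalization_stats(stats)

# pickle the key objects and rebuild an equivalent dataset from them later
dataset.save("dataset.pickle")
reloaded = BehaviorDataset(
    data_type="dlc_track",  # hypothetical; must match the original dataset
    annotation_type="csv",  # hypothetical; must match the original dataset
    saved_data_path="dataset.pickle",
)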
Parameters
data_type : str
the data type (see available types by running BehaviorDataset.data_types())
annotation_type : str
the annotation type (see available types by running BehaviorDataset.annotation_types())
ssl_transformations : list
a list of functions that take a sample dictionary as input and return an (ssl input, ssl target) tuple
saved_data_path : str
the path to a pre-computed pickled dataset
input_store : InputStore
a pre-computed input store
annotation_store : AnnotationStore
a pre-computed annotation store
only_load_annotated : bool
if True, the input files that don't have a matching annotation file will be disregarded
**data_parameters : dict
parameters to initialize the input and annotation stores
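For illustration, a minimal construction sketch (not part of the original source). The store-specific keyword arguments shown in comments are placeholders; the exact parameter names depend on the chosen data_type and annotation_type and can be listed with BehaviorDataset.get_parameters().

from dlc2action.data.dataset import BehaviorDataset

dataset = BehaviorDataset(
    data_type="...",           # one of BehaviorDataset.data_types()
    annotation_type="...",     # one of BehaviorDataset.annotation_types()
    only_load_annotated=True,  # skip input files without a matching annotation file
    # data_path="...",         # hypothetical store parameter
    # annotation_path="...",   # hypothetical store parameter
)
print(len(dataset), "samples")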
267 def save(self, save_path: str) -> None: 268 """ 269 Save the dictionary 270 271 Parameters 272 ---------- 273 save_path : str 274 the path where the pickled file will be stored 275 """ 276 277 input_obj = self.input_store.key_objects() 278 annotation_obj = self.annotation_store.key_objects() 279 with open(save_path, "wb") as f: 280 pickle.dump((input_obj, annotation_obj), f)
Save the dataset's key objects to a pickled file
Parameters
save_path : str
the path where the pickled file will be stored
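A usage sketch, assuming `dataset` is a constructed BehaviorDataset and the file path is arbitrary. Saving pickles the key objects of both stores, which can later be passed back in through the saved_data_path constructor argument.

dataset.save("dataset_key_objects.pickle")

# Later, rebuild the dataset from the pickled key objects instead of re-reading raw files:
reloaded = BehaviorDataset(
    data_type="...",         # the same data type as before
    annotation_type="...",   # the same annotation type as before
    saved_data_path="dataset_key_objects.pickle",
)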
282 def to_ram(self) -> None: 283 """ 284 Transfer the dataset to RAM 285 """ 286 287 self.input_store.to_ram() 288 self.annotation_store.to_ram()
Transfer the dataset to RAM
290 def generate_full_length_gt(self) -> Dict: 291 if self.annotation_class() == "exclusive_classification": 292 gt = torch.zeros((len(self), self.len_segment())) 293 else: 294 gt = torch.zeros( 295 (len(self), len(self.behaviors_dict()), self.len_segment()) 296 ) 297 for i in range(len(self)): 298 gt[i] = self.annotation_store[i] 299 return self.generate_full_length_prediction(gt)
301 def generate_full_length_prediction(self, predicted: torch.Tensor) -> Dict: 302 """ 303 Map predictions for the equal-length pieces to predictions for the original data 304 305 Probabilities are averaged over predictions on overlapping intervals. 306 307 Parameters 308 ---------- 309 predicted: torch.Tensor 310 a tensor of predicted probabilities of shape `(N, #classes, #frames)` 311 312 Returns 313 ------- 314 full_length_prediction : dict 315 a nested dictionary where first-level keys are video ids, second-level keys are clip ids and values are 316 averaged probability tensors 317 """ 318 319 result = defaultdict(lambda: {}) 320 counter = defaultdict(lambda: {}) 321 coordinates = self.input_store.get_original_coordinates() 322 for coords, prediction in zip(coordinates, predicted): 323 l = self.input_store.get_clip_length_from_coords(coords) 324 video_name = self.input_store.get_video_id(coords) 325 clip_id = self.input_store.get_clip_id(coords) 326 start, end = self.input_store.get_clip_start_end(coords) 327 if clip_id not in result[video_name].keys(): 328 result[video_name][clip_id] = torch.zeros(*prediction.shape[:-1], l) 329 counter[video_name][clip_id] = torch.zeros(*prediction.shape[:-1], l) 330 result[video_name][clip_id][..., start:end] += ( 331 prediction.squeeze()[..., : end - start].detach().cpu() 332 ) 333 counter[video_name][clip_id][..., start:end] += 1 334 for video_name in result: 335 for clip_id in result[video_name]: 336 result[video_name][clip_id] /= counter[video_name][clip_id] 337 result[video_name][clip_id][counter[video_name][clip_id] == 0] = -100 338 result = dict(result) 339 return result
Map predictions for the equal-length pieces to predictions for the original data
Probabilities are averaged over predictions on overlapping intervals.
Parameters
predicted: torch.Tensor
a tensor of predicted probabilities of shape (N, #classes, #frames)
Returns
full_length_prediction : dict
a nested dictionary where first-level keys are video ids, second-level keys are clip ids and values are averaged probability tensors
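A minimal sketch (not part of the source); the random tensor stands in for real model output of shape (#samples, #classes, #frames).

import torch

predicted = torch.rand(len(dataset), dataset.num_classes(), dataset.len_segment())
full_length = dataset.generate_full_length_prediction(predicted)
for video_id, clips in full_length.items():
    for clip_id, prob in clips.items():
        # prob has shape (#classes, original clip length); frames that were never
        # covered by any sample are filled with -100
        print(video_id, clip_id, prob.shape)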
341 def find_valleys( 342 self, 343 predicted: Union[torch.Tensor, Dict], 344 threshold: float = 0.5, 345 min_frames: int = 0, 346 visibility_min_score: float = 0, 347 visibility_min_frac: float = 0, 348 main_class: int = 1, 349 low: bool = True, 350 predicted_error: torch.Tensor = None, 351 error_threshold: float = 0.5, 352 hysteresis: bool = False, 353 threshold_diff: float = None, 354 min_frames_error: int = None, 355 smooth_interval: int = 1, 356 cut_annotated: bool = False, 357 ) -> Dict: 358 """ 359 Find the intervals where the probability of a certain class is below or above a certain hard_threshold 360 361 Parameters 362 ---------- 363 predicted : torch.Tensor | dict 364 either a tensor of predictions for the data prompts or the output of 365 `BehaviorDataset.generate_full_length_prediction` 366 threshold : float, default 0.5 367 the main hard_threshold 368 min_frames : int, default 0 369 the minimum length of the intervals 370 visibility_min_score : float, default 0 371 the minimum visibility score in the intervals 372 visibility_min_frac : float, default 0 373 fraction of the interval that has to have the visibility score larger than visibility_score_thr 374 main_class : int, default 1 375 the index of the class the function is inspecting 376 low : bool, default True 377 if True, the probability in the intervals has to be below the hard_threshold, and if False, it has to be above 378 predicted_error : torch.Tensor, optional 379 a tensor of error predictions for the data prompts 380 error_threshold : float, default 0.5 381 maximum possible probability of error at the intervals 382 hysteresis: bool, default False 383 if True, the function will apply a hysteresis hard_threshold with the soft hard_threshold defined by threshold_diff 384 threshold_diff: float, optional 385 the difference between the soft and hard hard_threshold if hysteresis is used; if hysteresis is True, low is False and threshold_diff is None, the soft hard_threshold condition is set to the main_class having a larger probability than other classes 386 min_frames_error: int, optional 387 if not None, the intervals will only be considered where the error probability is below error_threshold at at least min_frames_error consecutive frames 388 389 Returns 390 ------- 391 valleys : dict 392 a dictionary where keys are video ids and values are lists of (start, end, individual name) tuples that denote the chosen intervals 393 """ 394 395 result = defaultdict(lambda: []) 396 if type(predicted) is not dict: 397 predicted = self.generate_full_length_prediction(predicted) 398 if predicted_error is not None: 399 predicted_error = self.generate_full_length_prediction(predicted_error) 400 elif min_frames_error is not None and min_frames_error != 0: 401 # warnings.warn( 402 # f"The min_frames_error parameter is set to {min_frames_error} but no error prediction " 403 # f"is given! Setting min_frames_error to 0." 404 # ) 405 min_frames_error = 0 406 if low and hysteresis and threshold_diff is None: 407 raise ValueError( 408 "Cannot set low=True, hysteresis=True and threshold_diff=None! Please set threshold_diff." 
409 ) 410 if cut_annotated: 411 masked_intervals_dict = self.get_annotated_intervals() 412 else: 413 masked_intervals_dict = None 414 print("Valleys found:") 415 for v_id in predicted: 416 for clip_id in predicted[v_id].keys(): 417 if predicted_error is not None: 418 error_mask = predicted_error[v_id][clip_id][1, :] < error_threshold 419 if min_frames_error is not None: 420 output, indices, counts = torch.unique_consecutive( 421 error_mask, return_inverse=True, return_counts=True 422 ) 423 wrong_indices = torch.where( 424 output * (counts < min_frames_error) 425 )[0] 426 if len(wrong_indices) > 0: 427 for i in wrong_indices: 428 error_mask[indices == i] = False 429 else: 430 error_mask = None 431 if masked_intervals_dict is not None: 432 masked_intervals = masked_intervals_dict[v_id][clip_id] 433 else: 434 masked_intervals = None 435 if not hysteresis: 436 res_indices_start, res_indices_end = apply_threshold( 437 predicted[v_id][clip_id][main_class, :], 438 threshold, 439 low, 440 error_mask, 441 min_frames, 442 smooth_interval, 443 masked_intervals, 444 ) 445 elif threshold_diff is not None: 446 if low: 447 soft_threshold = threshold + threshold_diff 448 else: 449 soft_threshold = threshold - threshold_diff 450 res_indices_start, res_indices_end = apply_threshold_hysteresis( 451 predicted[v_id][clip_id][main_class, :], 452 soft_threshold, 453 threshold, 454 low, 455 error_mask, 456 min_frames, 457 smooth_interval, 458 masked_intervals, 459 ) 460 else: 461 res_indices_start, res_indices_end = apply_threshold_max( 462 predicted[v_id][clip_id], 463 threshold, 464 main_class, 465 error_mask, 466 min_frames, 467 smooth_interval, 468 masked_intervals, 469 ) 470 start = self.input_store.get_clip_start(v_id, clip_id) 471 result[v_id] += [ 472 [i + start, j + start, clip_id] 473 for i, j in zip(res_indices_start, res_indices_end) 474 if self.input_store.get_visibility( 475 v_id, clip_id, i, j, visibility_min_score 476 ) 477 > visibility_min_frac 478 ] 479 result[v_id] = sorted(result[v_id]) 480 print(f" {v_id}: {len(result[v_id])}") 481 return dict(result)
Find the intervals where the probability of a certain class is below or above a certain threshold
Parameters
predicted : torch.Tensor | dict
either a tensor of predictions for the data prompts or the output of BehaviorDataset.generate_full_length_prediction
threshold : float, default 0.5
the main threshold
min_frames : int, default 0
the minimum length of the intervals
visibility_min_score : float, default 0
the minimum visibility score in the intervals
visibility_min_frac : float, default 0
the minimum fraction of the interval that has to have a visibility score larger than visibility_min_score
main_class : int, default 1
the index of the class the function is inspecting
low : bool, default True
if True, the probability in the intervals has to be below the threshold, and if False, it has to be above
predicted_error : torch.Tensor, optional
a tensor of error predictions for the data prompts
error_threshold : float, default 0.5
maximum possible probability of error at the intervals
hysteresis : bool, default False
if True, the function will apply hysteresis thresholding with the soft threshold defined by threshold_diff
threshold_diff : float, optional
the difference between the soft and hard thresholds if hysteresis is used; if hysteresis is True, low is False and threshold_diff is None, the soft threshold condition is that main_class has a larger probability than the other classes
min_frames_error : int, optional
if not None, intervals are only considered where the error probability stays below error_threshold for at least min_frames_error consecutive frames
Returns
valleys : dict
a dictionary where keys are video ids and values are lists of (start, end, individual name) tuples that denote the chosen intervals
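A usage sketch, assuming `predicted` is the same kind of probability tensor as above (or the output of generate_full_length_prediction); the threshold values are arbitrary.

valleys = dataset.find_valleys(
    predicted,
    threshold=0.4,   # the hard threshold on the class probability
    min_frames=15,   # discard intervals shorter than 15 frames
    main_class=1,    # index of the class to inspect
    low=True,        # keep intervals where the probability is below the threshold
)
for video_id, intervals in valleys.items():
    for start, end, clip_id in intervals:
        print(video_id, clip_id, start, end)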
483 def valleys_union(self, valleys_list) -> Dict: 484 """ 485 Find the intersection of two valleys dictionaries 486 487 Parameters 488 ---------- 489 valleys_list : list 490 a list of valleys dictionaries 491 492 Returns 493 ------- 494 intersection : dict 495 a new valleys dictionary with the intersection of the input intervals 496 """ 497 498 valleys_list = [x for x in valleys_list if x is not None] 499 if len(valleys_list) == 1: 500 return valleys_list[0] 501 elif len(valleys_list) == 0: 502 return {} 503 union = {} 504 keys_list = [set(valleys.keys()) for valleys in valleys_list] 505 keys = set.union(*keys_list) 506 for v_id in keys: 507 res = [] 508 clips_list = [ 509 set([x[-1] for x in valleys[v_id]]) for valleys in valleys_list 510 ] 511 clips = set.union(*clips_list) 512 for clip_id in clips: 513 clip_intervals = [ 514 x 515 for valleys in valleys_list 516 for x in valleys[v_id] 517 if x[-1] == clip_id 518 ] 519 v_len = self.input_store.get_clip_length(v_id, clip_id) 520 arr = torch.zeros(v_len) 521 for start, end, _ in clip_intervals: 522 arr[start:end] += 1 523 output, indices, counts = torch.unique_consecutive( 524 arr > 0, return_inverse=True, return_counts=True 525 ) 526 long_indices = torch.where(output)[0] 527 res += [ 528 ( 529 (indices == i).nonzero(as_tuple=True)[0][0].item(), 530 (indices == i).nonzero(as_tuple=True)[0][-1].item(), 531 clip_id, 532 ) 533 for i in long_indices 534 ] 535 union[v_id] = res 536 return union
Find the union of a list of valleys dictionaries
Parameters
valleys_list : list
a list of valleys dictionaries
Returns
union : dict
a new valleys dictionary with the union of the input intervals
538 def valleys_intersection(self, valleys_list) -> Dict: 539 """ 540 Find the intersection of two valleys dictionaries 541 542 Parameters 543 ---------- 544 valleys_list : list 545 a list of valleys dictionaries 546 547 Returns 548 ------- 549 intersection : dict 550 a new valleys dictionary with the intersection of the input intervals 551 """ 552 553 valleys_list = [x for x in valleys_list if x is not None] 554 if len(valleys_list) == 1: 555 return valleys_list[0] 556 elif len(valleys_list) == 0: 557 return {} 558 intersection = {} 559 keys_list = [set(valleys.keys()) for valleys in valleys_list] 560 keys = set.intersection(*keys_list) 561 for v_id in keys: 562 res = [] 563 clips_list = [ 564 set([x[-1] for x in valleys[v_id]]) for valleys in valleys_list 565 ] 566 clips = set.intersection(*clips_list) 567 for clip_id in clips: 568 clip_intervals = [ 569 x 570 for valleys in valleys_list 571 for x in valleys[v_id] 572 if x[-1] == clip_id 573 ] 574 v_len = self.input_store.get_clip_length(v_id, clip_id) 575 arr = torch.zeros(v_len) 576 for start, end, _ in clip_intervals: 577 arr[start:end] += 1 578 output, indices, counts = torch.unique_consecutive( 579 arr, return_inverse=True, return_counts=True 580 ) 581 long_indices = torch.where(output == 2)[0] 582 res += [ 583 ( 584 (indices == i).nonzero(as_tuple=True)[0][0].item(), 585 (indices == i).nonzero(as_tuple=True)[0][-1].item(), 586 clip_id, 587 ) 588 for i in long_indices 589 ] 590 intersection[v_id] = res 591 return intersection
Find the intersection of two valleys dictionaries
Parameters
valleys_list : list
a list of valleys dictionaries
Returns
intersection : dict
a new valleys dictionary with the intersection of the input intervals
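A sketch of combining the outputs of several find_valleys calls (here for two hypothetical class indices):

valleys_a = dataset.find_valleys(predicted, main_class=1, low=True)
valleys_b = dataset.find_valleys(predicted, main_class=2, low=True)
either = dataset.valleys_union([valleys_a, valleys_b])       # intervals covered by at least one result
both = dataset.valleys_intersection([valleys_a, valleys_b])  # intervals covered by both results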
593 def partition_train_test_val( 594 self, 595 use_test: float = 0, 596 split_path: str = None, 597 method: str = "random", 598 val_frac: float = 0, 599 test_frac: float = 0, 600 save_split: bool = False, 601 normalize: bool = False, 602 skip_normalization_keys: List = None, 603 stats: Dict = None, 604 ) -> Tuple: 605 """ 606 Partition the dataset into three new datasets 607 608 Parameters 609 ---------- 610 use_test : float, default 0 611 The fraction of the test dataset to be used in training without labels 612 split_path : str, optional 613 The path to load the split information from (if `'file'` method is used) and to save it to 614 (if `'save_split'` is `True`) 615 method : {'random', 'random:test-from-name', 'random:test-from-name:{name}', 616 'val-from-name:{val_name}:test-from-name:{test_name}', 617 'random:equalize:segments', 'random:equalize:videos', 618 'folders', 'time', 'time:strict', 'file'} 619 The partitioning method: 620 - `'random'`: sort videos into subsets randomly, 621 - `'random:test-from-name'` (or `'random:test-from-name:{name}'`): sort videos into training and validation 622 subsets randomly and create 623 the test subset from the video ids that start with a speific substring (`'test'` by default, or `name` 624 if provided), 625 - `'random:equalize:segments'` and `'random:equalize:videos'`: sort videos into subsets randomly but 626 making sure that for the rarest classes at least `0.8 * val_frac` of the videos/segments that contain 627 occurences of the class get into the validation subset and `0.8 * test_frac` get into the test subset; 628 this in ensured for all classes in order of increasing number of occurences until the validation and 629 test subsets are full 630 - `'val-from-name:{val_name}:test-from-name:{test_name}'`: create the validation and test 631 subsets from the video ids that start with specific substrings (`val_name` for validation 632 and `test_name` for test) and sort all other videos into the training subset 633 - `'folders'`: read videos from folders named *test*, *train* and *val* into corresponding subsets, 634 - `'time'`: split each video into training, validation and test subsequences, 635 - `'time:strict'`: split each video into validation, test and training subsequences 636 and throw out the last segments in validation and test (to get rid of overlaps), 637 - `'file'`: split according to a split file. 
638 val_frac : float, default 0 639 The fraction of the dataset to be used in validation 640 test_frac : float, default 0 641 The fraction of the dataset to be used in test 642 save_split : bool, default False 643 Save a split file if True 644 645 Returns 646 ------- 647 train_dataset : BehaviorDataset 648 train dataset 649 650 val_dataset : BehaviorDataset 651 validation dataset 652 653 test_dataset : BehaviorDataset 654 test dataset 655 """ 656 657 train_indices, test_indices, val_indices = self._partition_indices( 658 split_path=split_path, 659 method=method, 660 val_frac=val_frac, 661 test_frac=test_frac, 662 save_split=save_split, 663 ) 664 val_dataset = self._create_new_dataset(val_indices) 665 test_dataset = self._create_new_dataset(test_indices) 666 train_dataset = self._create_new_dataset( 667 train_indices, ssl_indices=test_indices[: int(len(test_indices) * use_test)] 668 ) 669 print("Number of samples:") 670 print(f" validation:") 671 print(f" {val_dataset.count_classes()}") 672 print(f" training:") 673 print(f" {train_dataset.count_classes()}") 674 print(f" test:") 675 print(f" {test_dataset.count_classes()}") 676 if normalize: 677 if stats is None: 678 print("Computing normalization statistics...") 679 stats = train_dataset.get_normalization_stats(skip_normalization_keys) 680 else: 681 print("Setting loaded normalization statistics...") 682 train_dataset.set_normalization_stats(stats) 683 val_dataset.set_normalization_stats(stats) 684 test_dataset.set_normalization_stats(stats) 685 return train_dataset, test_dataset, val_dataset
Partition the dataset into three new datasets
Parameters
use_test : float, default 0
The fraction of the test dataset to be used in training without labels
split_path : str, optional
The path to load the split information from (if the 'file' method is used) and to save it to (if save_split is True)
method : {'random', 'random:test-from-name', 'random:test-from-name:{name}', 'val-from-name:{val_name}:test-from-name:{test_name}', 'random:equalize:segments', 'random:equalize:videos', 'folders', 'time', 'time:strict', 'file'}
The partitioning method:
- 'random': sort videos into subsets randomly,
- 'random:test-from-name' (or 'random:test-from-name:{name}'): sort videos into training and validation subsets randomly and create the test subset from the video ids that start with a specific substring ('test' by default, or name if provided),
- 'random:equalize:segments' and 'random:equalize:videos': sort videos into subsets randomly but make sure that for the rarest classes at least 0.8 * val_frac of the videos/segments that contain occurrences of the class get into the validation subset and 0.8 * test_frac get into the test subset; this is ensured for all classes in order of increasing number of occurrences until the validation and test subsets are full,
- 'val-from-name:{val_name}:test-from-name:{test_name}': create the validation and test subsets from the video ids that start with specific substrings (val_name for validation and test_name for test) and sort all other videos into the training subset,
- 'folders': read videos from folders named test, train and val into the corresponding subsets,
- 'time': split each video into training, validation and test subsequences,
- 'time:strict': split each video into validation, test and training subsequences and throw out the last segments in validation and test (to get rid of overlaps),
- 'file': split according to a split file.
val_frac : float, default 0
The fraction of the dataset to be used in validation
test_frac : float, default 0
The fraction of the dataset to be used in test
save_split : bool, default False
Save a split file if True
Returns
train_dataset : BehaviorDataset
the train dataset
val_dataset : BehaviorDataset
the validation dataset
test_dataset : BehaviorDataset
the test dataset
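A usage sketch for a random 70/15/15 split with runtime normalization; note that, as in the source above, the datasets are returned in (train, test, val) order. The split file path is hypothetical.

train_ds, test_ds, val_ds = dataset.partition_train_test_val(
    method="random",
    val_frac=0.15,
    test_frac=0.15,
    normalize=True,     # compute normalization stats on the training subset and share them
    save_split=True,
    split_path="split.txt",
)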
687 def class_weights(self, proportional=False) -> List: 688 """ 689 Calculate class weights in inverse proportion to number of samples 690 Returns 691 ------- 692 weights: list 693 a list of class weights 694 """ 695 696 items = sorted( 697 [ 698 (k, v) 699 for k, v in self.annotation_store.count_classes().items() 700 if k != -100 701 ] 702 ) 703 if self.annotation_store.annotation_class() == "exclusive_classification": 704 if not proportional: 705 numerator = len(self.annotation_store) 706 else: 707 numerator = max([x[1] for x in items]) 708 weights = [numerator / (v + 1e-7) for _, v in items] 709 else: 710 items_zero = sorted( 711 [ 712 (k, v) 713 for k, v in self.annotation_store.count_classes(zeros=True).items() 714 if k != -100 715 ] 716 ) 717 if not proportional: 718 numerators = defaultdict(lambda: len(self.annotation_store)) 719 else: 720 numerators = { 721 item_one[0]: max(item_one[1], item_zero[1]) 722 for item_one, item_zero in zip(items, items_zero) 723 } 724 weights = {} 725 weights[0] = [numerators[k] / (v + 1e-7) for k, v in items_zero] 726 weights[1] = [numerators[k] / (v + 1e-7) for k, v in items] 727 return weights
Calculate class weights in inverse proportion to the number of samples
Returns
weights : list
a list of class weights (for non-exclusive annotation, a dictionary with weight lists for the negative and positive labels under keys 0 and 1 is returned instead)
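A sketch of feeding the weights into a loss function for an exclusive classification dataset (for non-exclusive annotation the method returns a dictionary of lists instead); `train_ds` is assumed to be a training subset as above.

import torch

weights = train_ds.class_weights(proportional=False)
criterion = torch.nn.CrossEntropyLoss(
    weight=torch.tensor(weights, dtype=torch.float), ignore_index=-100
)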
729 def boundary_class_weight(self): 730 if self.annotation_type != "none": 731 f = self.annotation_store.data.flatten() 732 _, inv = torch.unique_consecutive(f, return_inverse=True) 733 boundary = torch.cat([torch.tensor([0]), torch.diff(inv)]).reshape( 734 self.annotation_store.data.shape 735 ) 736 boundary[..., 0] = 0 737 cnt = Counter(boundary.flatten().numpy()) 738 return cnt[1] / cnt[0] 739 else: 740 return 0
742 def count_classes(self, bouts: bool = False) -> Dict: 743 """ 744 Get a class counter dictionary 745 746 Parameters 747 ---------- 748 bouts : bool, default False 749 if `True`, instead of frame counts segment counts are returned 750 751 Returns 752 ------- 753 count_dictionary : dict 754 a dictionary with class indices as keys and frame or bout counts as values 755 """ 756 757 return self.annotation_store.count_classes(bouts=bouts)
Get a class counter dictionary
Parameters
bouts : bool, default False
if True, segment (bout) counts are returned instead of frame counts
Returns
count_dictionary : dict
a dictionary with class indices as keys and frame or bout counts as values
759 def behaviors_dict(self) -> Dict: 760 """ 761 Get a behavior dictionary 762 763 Returns 764 ------- 765 dict 766 behavior dictionary 767 """ 768 769 return self.annotation_store.behaviors_dict()
Get a behavior dictionary
Returns
dict behavior dictionary
779 def features_shape(self) -> Dict: 780 """ 781 Get the shapes of the input features 782 783 Returns 784 ------- 785 shapes : Dict 786 a dictionary with the shapes of the features 787 """ 788 789 sample = self.input_store[0] 790 shapes = {k: v.shape for k, v in sample.items()} 791 # for key, value in shapes.items(): 792 # print(f'{key}: {value}') 793 return shapes
Get the shapes of the input features
Returns
shapes : Dict
a dictionary with the shapes of the features
795 def num_classes(self) -> int: 796 """ 797 Get the number of classes in the data 798 799 Returns 800 ------- 801 num_classes : int 802 the number of classes 803 """ 804 805 return len(self.annotation_store.behaviors_dict())
Get the number of classes in the data
Returns
num_classes : int the number of classes
807 def len_segment(self) -> int: 808 """ 809 Get the segment length in the data 810 811 Returns 812 ------- 813 len_segment : int 814 the segment length 815 """ 816 817 sample = self.input_store[0] 818 key = list(sample.keys())[0] 819 return sample[key].shape[-1]
Get the segment length in the data
Returns
len_segment : int the segment length
821 def set_ssl_transformations(self, ssl_transformations: List) -> None: 822 """ 823 Set new SSL transformations 824 825 Parameters 826 ---------- 827 ssl_transformations : list 828 a list of functions that take a sample feature dictionary as input and output ssl_inputs and ssl_targets 829 lists 830 """ 831 832 self.ssl_transformations = ssl_transformations
Set new SSL transformations
Parameters
ssl_transformations : list
a list of functions that take a sample feature dictionary as input and output ssl_inputs and ssl_targets lists
834 @classmethod 835 def new(cls, *args, **kwargs): 836 """ 837 Create a new object of the same class 838 839 Returns 840 ------- 841 new_instance: BehaviorDataset 842 a new instance of the same class 843 """ 844 845 return cls(*args, **kwargs)
Create a new object of the same class
Returns
new_instance: BehaviorDataset a new instance of the same class
847 @classmethod 848 def get_parameters(cls, data_type: str, annotation_type: str) -> List: 849 """ 850 Get parameters necessary for initialization 851 852 Parameters 853 ---------- 854 data_type : str 855 the data type 856 annotation_type : str 857 the annotation type 858 """ 859 860 input_features = options.input_stores[data_type].get_parameters() 861 annotation_features = options.annotation_stores[ 862 annotation_type 863 ].get_parameters() 864 self_features = inspect.getfullargspec(cls.__init__).args 865 return self_features + input_features + annotation_features
Get parameters necessary for initialization
Parameters
data_type : str
the data type
annotation_type : str
the annotation type
867 @staticmethod 868 def data_types() -> List: 869 """ 870 List available data types 871 872 Returns 873 ------- 874 data_types : list 875 available data types 876 """ 877 878 return list(options.input_stores.keys())
List available data types
Returns
data_types : list available data types
880 @staticmethod 881 def annotation_types() -> List: 882 """ 883 List available annotation types 884 885 Returns 886 ------- 887 annotation_types : list 888 available annotation types 889 """ 890 891 return list(options.annotation_stores.keys())
List available annotation types
Returns
annotation_types : list available annotation types
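A discovery sketch; the placeholder strings stand for real type names from the printed lists.

print(BehaviorDataset.data_types())        # registered input store types
print(BehaviorDataset.annotation_types())  # registered annotation store types
# get_parameters lists the constructor arguments accepted by the dataset itself
# plus the parameters of the chosen input and annotation stores:
print(BehaviorDataset.get_parameters(data_type="...", annotation_type="..."))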
965 def set_indexing_parameters(self, unlabeled: bool, tag: int) -> None: 966 """ 967 Set the parameters that change the subset that is returned at `__getitem__` 968 969 Parameters 970 ---------- 971 unlabeled : bool 972 a pseudolabeling parameter; return only unlabeled samples if `True`, only labeled if `False` and 973 all if `None` 974 tag : int 975 if not `None`, only samples with this meta tag will be returned 976 """ 977 978 if unlabeled != self.return_unlabeled: 979 self.annotation_indices = self.annotation_store.get_indices(unlabeled) 980 self.return_unlabeled = unlabeled 981 if tag != self.tag: 982 self.input_indices = self.input_store.get_indices(tag) 983 self.tag = tag 984 self.indices = [x for x in self.annotation_indices if x in self.input_indices]
Set the parameters that change the subset that is returned at __getitem__
Parameters
unlabeled : bool
a pseudolabeling parameter; return only unlabeled samples if True, only labeled if False and all if None
tag : int
if not None, only samples with this meta tag will be returned
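A sketch of temporarily restricting the dataset to unlabeled samples (e.g. for pseudolabeling) and then restoring the full index:

dataset.set_indexing_parameters(unlabeled=True, tag=None)
print(len(dataset.indices), "unlabeled samples")
dataset.set_indexing_parameters(unlabeled=None, tag=None)  # back to all samples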
1410 def get_intervals(self) -> Tuple[dict, Optional[list]]: 1411 """ 1412 Get a list of intervals covered by the dataset in the original coordinates 1413 1414 Returns 1415 ------- 1416 intervals : dict 1417 a nested dictionary where first-level keys are video ids, second-level keys are clip ids and 1418 values are lists of the intervals in `[start, end]` format 1419 """ 1420 1421 counter = defaultdict(lambda: {}) 1422 coordinates = self.input_store.get_original_coordinates() 1423 for coords in coordinates: 1424 l = self.input_store.get_clip_length_from_coords(coords) 1425 video_name = self.input_store.get_video_id(coords) 1426 clip_id = self.input_store.get_clip_id(coords) 1427 start, end = self.input_store.get_clip_start_end(coords) 1428 if clip_id not in counter[video_name]: 1429 counter[video_name][clip_id] = np.zeros(l) 1430 counter[video_name][clip_id][start:end] = 1 1431 result = {video_name: {} for video_name in counter} 1432 for video_name in counter: 1433 for clip_id in counter[video_name]: 1434 result[video_name][clip_id] = self._get_intervals_from_ind( 1435 np.where(counter[video_name][clip_id])[0] 1436 ) 1437 return result, self.ids
Get a list of intervals covered by the dataset in the original coordinates
Returns
intervals : dict
a nested dictionary where first-level keys are video ids, second-level keys are clip ids and values are lists of the intervals in [start, end] format
1439 def get_unannotated_intervals(self, first_intervals=None) -> Dict: 1440 """ 1441 Get a list of intervals in the original coordinates where there is no annotation 1442 1443 Returns 1444 ------- 1445 intervals : dict 1446 a nested dictionary where first-level keys are video ids, second-level keys are clip ids and 1447 values are lists of the intervals in `[start, end]` format 1448 """ 1449 1450 counter_value = 2 1451 if first_intervals is None: 1452 first_intervals = defaultdict(lambda: defaultdict(lambda: [])) 1453 counter_value = 1 1454 counter = defaultdict(lambda: {}) 1455 coordinates = self.input_store.get_original_coordinates() 1456 for i, coords in enumerate(coordinates): 1457 l = self.input_store.get_clip_length_from_coords(coords) 1458 ann = self.annotation_store[i] 1459 if ( 1460 self.annotation_store.annotation_class() 1461 == "nonexclusive_classification" 1462 ): 1463 ann = ann[0, :] 1464 video_name = self.input_store.get_video_id(coords) 1465 clip_id = self.input_store.get_clip_id(coords) 1466 start, end = self.input_store.get_clip_start_end(coords) 1467 if clip_id not in counter[video_name]: 1468 counter[video_name][clip_id] = np.ones(l) 1469 counter[video_name][clip_id][start:end] = (ann[: end - start] == -100).int() 1470 result = {video_name: {} for video_name in counter} 1471 for video_name in counter: 1472 for clip_id in counter[video_name]: 1473 for start, end in first_intervals[video_name][clip_id]: 1474 counter[video_name][clip_id][start:end] += 1 1475 result[video_name][clip_id] = self._get_intervals_from_ind( 1476 np.where(counter[video_name][clip_id] == counter_value)[0] 1477 ) 1478 return result
Get a list of intervals in the original coordinates where there is no annotation
Returns
intervals : dict
a nested dictionary where first-level keys are video ids, second-level keys are clip ids and values are lists of the intervals in [start, end] format
1480 def get_annotated_intervals(self) -> Dict: 1481 """ 1482 Get a list of intervals in the original coordinates where there is no annotation 1483 1484 Returns 1485 ------- 1486 intervals : dict 1487 a nested dictionary where first-level keys are video ids, second-level keys are clip ids and 1488 values are lists of the intervals in `[start, end]` format 1489 """ 1490 1491 if self.annotation_type == "none": 1492 return [] 1493 counter_value = 1 1494 counter = defaultdict(lambda: {}) 1495 coordinates = self.input_store.get_original_coordinates() 1496 for i, coords in enumerate(coordinates): 1497 l = self.input_store.get_clip_length_from_coords(coords) 1498 ann = self.annotation_store[i] 1499 video_name = self.input_store.get_video_id(coords) 1500 clip_id = self.input_store.get_clip_id(coords) 1501 start, end = self.input_store.get_clip_start_end(coords) 1502 if clip_id not in counter[video_name]: 1503 counter[video_name][clip_id] = np.zeros(l) 1504 if ( 1505 self.annotation_store.annotation_class() 1506 == "nonexclusive_classification" 1507 ): 1508 counter[video_name][clip_id][start:end] = ( 1509 torch.sum(ann[:, : end - start] != -100, dim=0) > 0 1510 ).int() 1511 else: 1512 counter[video_name][clip_id][start:end] = ( 1513 ann[: end - start] != -100 1514 ).int() 1515 result = {video_name: {} for video_name in counter} 1516 for video_name in counter: 1517 for clip_id in counter[video_name]: 1518 result[video_name][clip_id] = self._get_intervals_from_ind( 1519 np.where(counter[video_name][clip_id] == counter_value)[0] 1520 ) 1521 return result
Get a list of intervals in the original coordinates where there is annotation
Returns
intervals : dict
a nested dictionary where first-level keys are video ids, second-level keys are clip ids and values are lists of the intervals in [start, end] format
1523 def get_ids(self) -> Dict: 1524 """ 1525 Get a dictionary of all clip ids in the dataset 1526 1527 Returns 1528 ------- 1529 ids : dict 1530 a dictionary where keys are video ids and values are lists of clip ids 1531 """ 1532 1533 coordinates = self.input_store.get_original_coordinates() 1534 video_ids = np.array(self.input_store.get_video_id_order()) 1535 id_set = set(video_ids) 1536 result = {} 1537 for video_id in id_set: 1538 coords = coordinates[video_ids == video_id] 1539 clip_ids = list({self.input_store.get_clip_id(c) for c in coords}) 1540 result[video_id] = clip_ids 1541 return result
Get a dictionary of all clip ids in the dataset
Returns
ids : dict
a dictionary where keys are video ids and values are lists of clip ids
1543 def get_len(self, video_id: str, clip_id: str) -> int: 1544 """ 1545 Get the length of a specific clip 1546 1547 Parameters 1548 ---------- 1549 video_id : str 1550 the video id 1551 clip_id : str 1552 the clip id 1553 1554 Returns 1555 ------- 1556 length : int 1557 the length 1558 """ 1559 1560 return self.input_store.get_clip_length(video_id, clip_id)
Get the length of a specific clip
Parameters
video_id : str
the video id
clip_id : str
the clip id
Returns
length : int
the clip length
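A sketch of walking over every clip in the dataset and printing its length:

for video_id, clip_ids in dataset.get_ids().items():
    for clip_id in clip_ids:
        print(video_id, clip_id, dataset.get_len(video_id, clip_id), "frames")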
1562 def get_confusion_matrix( 1563 self, prediction: torch.Tensor, confusion_type: str = "recall" 1564 ) -> Tuple[ndarray, list]: 1565 """ 1566 Get a confusion matrix 1567 1568 Parameters 1569 ---------- 1570 prediction : torch.Tensor 1571 a tensor of predicted class probabilities of shape `(#samples, #classes, #frames)` 1572 confusion_type : {"recall", "precision"} 1573 for datasets with non-exclusive annotation, if `type` is `"recall"`, only false positives are taken 1574 into account, and if `type` is `"precision"`, only false negatives 1575 1576 Returns 1577 ------- 1578 confusion_matrix : np.ndarray 1579 a confusion matrix of shape `(#classes, #classes)` where `A[i, j] = F_ij/N_i`, `F_ij` is the number of 1580 frames that have the i-th label in the ground truth and a false positive j-th label in the prediction, 1581 `N_i` is the number of frames that have the i-th label in the ground truth 1582 classes : list 1583 a list of classes 1584 """ 1585 1586 behaviors_dict = self.annotation_store.behaviors_dict() 1587 num_behaviors = len(behaviors_dict) 1588 confusion_matrix = np.zeros((num_behaviors, num_behaviors)) 1589 if self.annotation_store.annotation_class() == "exclusive_classification": 1590 exclusive = True 1591 confusion_type = None 1592 elif self.annotation_store.annotation_class() == "nonexclusive_classification": 1593 exclusive = False 1594 else: 1595 raise RuntimeError( 1596 f"The {self.annotation_store.annotation_class()} annotation class is not recognized!" 1597 ) 1598 for ann, p in zip(self.annotation_store, prediction): 1599 if exclusive: 1600 class_prediction = torch.max(p, dim=0)[1] 1601 for i in behaviors_dict.keys(): 1602 for j in behaviors_dict.keys(): 1603 confusion_matrix[i, j] += int( 1604 torch.sum(class_prediction[ann == i] == j) 1605 ) 1606 else: 1607 class_prediction = (p > 0.5).int() 1608 for i in behaviors_dict.keys(): 1609 for j in behaviors_dict.keys(): 1610 if confusion_type == "recall": 1611 pred = deepcopy(class_prediction[j]) 1612 if i != j: 1613 pred[ann[j] == 1] = 0 1614 confusion_matrix[i, j] += int(torch.sum(pred[ann[i] == 1])) 1615 elif confusion_type == "precision": 1616 annotation = deepcopy(ann[j]) 1617 if i != j: 1618 annotation[class_prediction[j] == 1] = 0 1619 confusion_matrix[i, j] += int( 1620 torch.sum(annotation[class_prediction[i] == 1]) 1621 ) 1622 else: 1623 raise ValueError( 1624 f"The {confusion_type} type is not recognized; please choose from ['recall', 'precision']" 1625 ) 1626 counter = self.annotation_store.count_classes() 1627 for i in behaviors_dict.keys(): 1628 if counter[i] != 0: 1629 if confusion_type == "recall" or confusion_type is None: 1630 confusion_matrix[i, :] /= counter[i] 1631 else: 1632 confusion_matrix[:, i] /= counter[i] 1633 return confusion_matrix, list(behaviors_dict.values()), confusion_type
Get a confusion matrix
Parameters
prediction : torch.Tensor
a tensor of predicted class probabilities of shape (#samples, #classes, #frames)
confusion_type : {"recall", "precision"}
for datasets with non-exclusive annotation, if confusion_type is "recall", only false positives are taken into account, and if it is "precision", only false negatives
Returns
confusion_matrix : np.ndarray
a confusion matrix of shape (#classes, #classes) where A[i, j] = F_ij / N_i, F_ij is the number of frames that have the i-th label in the ground truth and a false positive j-th label in the prediction, and N_i is the number of frames that have the i-th label in the ground truth
classes : list
a list of classes
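A sketch of computing a recall-style confusion matrix; the random tensor again stands in for real model probabilities.

import torch

predicted = torch.rand(len(dataset), dataset.num_classes(), dataset.len_segment())
matrix, classes, confusion_type = dataset.get_confusion_matrix(predicted, confusion_type="recall")
for i, name in enumerate(classes):
    print(name, [round(float(x), 2) for x in matrix[i]])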
1635 def annotation_class(self) -> str: 1636 """ 1637 Get the type of annotation ('exclusive_classification', 'nonexclusive_classification', more coming soon) 1638 1639 Returns 1640 ------- 1641 annotation_class : str 1642 the type of annotation 1643 """ 1644 1645 return self.annotation_store.annotation_class()
Get the type of annotation ('exclusive_classification', 'nonexclusive_classification', more coming soon)
Returns
annotation_class : str the type of annotation
1647 def set_normalization_stats(self, stats: Dict) -> None: 1648 """ 1649 Set the stats to normalize data at runtime 1650 1651 Parameters 1652 ---------- 1653 stats : dict 1654 a nested dictionary where first-level keys are feature key names, second-level keys are 'mean' and 'std' 1655 and values are the statistics in `torch` tensors of shape `(#features, 1)` 1656 """ 1657 1658 self.stats = stats
Set the stats to normalize data at runtime
Parameters
stats : dict
a nested dictionary where first-level keys are feature key names, second-level keys are 'mean' and 'std', and values are the statistics in torch tensors of shape (#features, 1)
1660 def get_min_max_frames(self, video_id) -> Tuple[Dict, Dict]: 1661 coords = self.input_store.get_original_coordinates() 1662 clips = set( 1663 [ 1664 self.input_store.get_clip_id(c) 1665 for c in coords 1666 if self.input_store.get_video_id(c) == video_id 1667 ] 1668 ) 1669 min_frames = {} 1670 max_frames = {} 1671 for clip in clips: 1672 start = self.input_store.get_clip_start(video_id, clip) 1673 end = start + self.input_store.get_clip_length(video_id, clip) 1674 min_frames[clip] = start 1675 max_frames[clip] = end - 1 1676 return min_frames, max_frames
1678 def get_normalization_stats(self, skip_keys=None) -> Dict: 1679 """ 1680 Get mean and standard deviation for each key 1681 1682 Returns 1683 ------- 1684 stats : dict 1685 a nested dictionary where first-level keys are feature key names, second-level keys are 'mean' and 'std' 1686 and values are the statistics in `torch` tensors of shape `(#features, 1)` 1687 """ 1688 1689 stats = defaultdict(lambda: {}) 1690 sums = defaultdict(lambda: 0) 1691 if skip_keys is None: 1692 skip_keys = [] 1693 counter = defaultdict(lambda: 0) 1694 for sample in tqdm(self): 1695 for key, value in sample["input"].items(): 1696 key_name = key.split("---")[0] 1697 if key_name not in skip_keys: 1698 sums[key_name] += value[:, value.sum(0) != 0].sum(-1) 1699 counter[key_name] += torch.sum(value.sum(0) != 0) 1700 for key, value in sums.items(): 1701 stats[key]["mean"] = (value / counter[key]).unsqueeze(-1) 1702 sums = defaultdict(lambda: 0) 1703 for sample in tqdm(self): 1704 for key, value in sample["input"].items(): 1705 key_name = key.split("---")[0] 1706 if key_name not in skip_keys: 1707 sums[key_name] += ( 1708 (value[:, value.sum(0) != 0] - stats[key_name]["mean"]) ** 2 1709 ).sum(-1) 1710 for key, value in sums.items(): 1711 stats[key]["std"] = np.sqrt(value.unsqueeze(-1) / counter[key]) 1712 return stats
Get mean and standard deviation for each key
Returns
stats : dict
a nested dictionary where first-level keys are feature key names, second-level keys are 'mean' and 'std', and values are the statistics in torch tensors of shape (#features, 1)
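A sketch of computing the statistics on a training subset and sharing them with the other subsets (this is what partition_train_test_val does when normalize=True); the skipped key name is hypothetical.

stats = train_ds.get_normalization_stats(skip_keys=["some_feature"])  # "some_feature" is a placeholder key
val_ds.set_normalization_stats(stats)
test_ds.set_normalization_stats(stats)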