dlc2action.project.project      
                        Project interface
1# 2# Copyright 2020-present by A. Mathis Group and contributors. All rights reserved. 3# 4# This project and all its files are licensed under GNU AGPLv3 or later version. 5# A copy is included in dlc2action/LICENSE.AGPL. 6# 7""" 8Project interface 9""" 10 11import gc 12import os 13import pickle 14import shutil 15import time 16import warnings 17from abc import abstractmethod 18from collections import defaultdict 19from collections.abc import Iterable, Mapping 20from copy import deepcopy 21from email.policy import default 22from itertools import product 23from pathlib import Path 24from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union 25 26import cv2 27import numpy as np 28import optuna 29import pandas as pd 30import plotly 31import torch 32from matplotlib import cm 33from matplotlib import pyplot as plt 34from matplotlib import rc 35from numpy import ndarray 36from ruamel.yaml import YAML 37from ruamel.yaml.comments import CommentedMap, CommentedSet 38from tqdm import tqdm 39 40from dlc2action import __version__, options 41from dlc2action.data.dataset import BehaviorDataset 42from dlc2action.project.meta import ( 43 DecisionThresholds, 44 Run, 45 SavedRuns, 46 SavedStores, 47 Searches, 48 Suggestions, 49) 50from dlc2action.task.task_dispatcher import TaskDispatcher 51from dlc2action.utils import apply_threshold, binarize_data, load_pickle 52 53 54class Project: 55 """A class to create and maintain the project files + keep track of experiments.""" 56 57 def __init__( 58 self, 59 name: str, 60 data_type: str = None, 61 annotation_type: str = "none", 62 projects_path: str = None, 63 data_path: Union[str, List] = None, 64 annotation_path: Union[str, List] = None, 65 copy: bool = False, 66 ) -> None: 67 """Initialize the class. 
68 69 Parameters 70 ---------- 71 name : str 72 name of the project 73 data_type : str, optional 74 data type (run Project.data_types() to see available options; has to be provided if the project is being 75 created) 76 annotation_type : str, default 'none' 77 annotation type (run Project.annotation_types() to see available options) 78 projects_path : str, optional 79 path to the projects folder (is filled with ~/DLC2Action by default) 80 data_path : str, optional 81 path to the folder containing input files for the project (has to be provided if the project is being 82 created) 83 annotation_path : str, optional 84 path to the folder containing annotation files for the project 85 copy : bool, default False 86 if True, the files from annotation_path and data_path will be copied to the projects folder; 87 otherwise they will be moved 88 89 """ 90 if projects_path is None: 91 projects_path = os.path.join(str(Path.home()), "DLC2Action") 92 if not os.path.exists(projects_path): 93 os.mkdir(projects_path) 94 self.project_path = os.path.join(projects_path, name) 95 self.name = name 96 self.data_type = data_type 97 self.annotation_type = annotation_type 98 self.data_path = data_path 99 self.annotation_path = annotation_path 100 if not os.path.exists(self.project_path): 101 if data_type is None: 102 raise ValueError( 103 "The data_type parameter is necessary when creating a new project!" 104 ) 105 self._initialize_project( 106 data_type, annotation_type, data_path, annotation_path, copy 107 ) 108 else: 109 self.annotation_type, self.data_type = self._read_types() 110 if data_type != self.data_type and data_type is not None: 111 raise ValueError( 112 f"The project has already been initialized with data_type={self.data_type}!" 113 ) 114 if annotation_type != self.annotation_type and annotation_type != "none": 115 raise ValueError( 116 f"The project has already been initialized with annotation_type={self.annotation_type}!" 
117 ) 118 self.annotation_path, data_path = self._read_paths() 119 if self.data_path is None: 120 self.data_path = data_path 121 # if data_path != self.data_path and data_path is not None: 122 # raise ValueError( 123 # f"The project has already been initialized with data_path={self.data_path}!" 124 # ) 125 if annotation_path != self.annotation_path and annotation_path is not None: 126 raise ValueError( 127 f"The project has already been initialized with annotation_path={self.annotation_path}!" 128 ) 129 self._update_configs() 130 131 def _make_prediction( 132 self, 133 prediction_name: str, 134 episode_names: List, 135 load_epochs: Union[List[int], int] = None, 136 parameters_update: Dict = None, 137 data_path: str = None, 138 file_paths: Set = None, 139 mode: str = "all", 140 augment_n: int = 0, 141 evaluate: bool = False, 142 task: TaskDispatcher = None, 143 embedding: bool = False, 144 annotation_type: str = "none", 145 ) -> Tuple[TaskDispatcher, Dict, str, torch.Tensor]: 146 """Generate a prediction. 
147 Parameters 148 ---------- 149 prediction_name : str 150 name of the prediction 151 episode_names : List 152 names of the episodes to use for the prediction 153 load_epochs : Union[List[int],int], optional 154 epochs to load for each episode; if a single integer is provided, it will be used for all episodes; 155 if None, the last epochs will be used 156 parameters_update : Dict, optional 157 dictionary with parameters to update the task parameters 158 data_path : str, optional 159 path to the data folder; if None, the data_path from the project will be used 160 file_paths : Set, optional 161 set of file paths to use for the prediction; if None, the data_path will be used 162 mode : str, default "all 163 mode of the prediction; can be "train", "val", "test" or "all 164 augment_n : int, default 0 165 number of augmentations to apply to the data; if 0, no augmentations are applied 166 evaluate : bool, default False 167 if True, the prediction will be evaluated and the results will be saved to the episode meta file 168 task : TaskDispatcher, optional 169 task object to use for the prediction; if None, a new task object will be created 170 embedding : bool, default False 171 if True, the prediction will be returned as an embedding 172 annotation_type : str, default "none 173 type of the annotation to use for the prediction; if "none", the annotation will not be used 174 Returns 175 ------- 176 task : TaskDispatcher 177 task object used for the prediction 178 parameters : Dict 179 parameters used for the prediction 180 mode : str 181 mode of the prediction 182 prediction : torch.Tensor 183 prediction tensor of shape (num_videos, num_behaviors, num_frames) 184 inference_time : str 185 time taken for the prediction in the format "HH:MM:SS" 186 behavior_dict : Dict 187 dictionary with behavior names and their indices 188 """ 189 190 names = [] 191 for episode_name in episode_names: 192 names += self._episodes().get_runs(episode_name) 193 if len(names) == 0: 194 
warnings.warn(f"None of the episodes {episode_names} exist!") 195 names = [None] 196 if load_epochs is None: 197 load_epochs = [None for _ in names] 198 elif isinstance(load_epochs, int): 199 load_epochs = [load_epochs for _ in names] 200 assert len(load_epochs) == len( 201 names 202 ), f"Length of load_epochs ({len(load_epochs)}) must match the number of episodes ({len(names)})!" 203 prediction = None 204 decision_thresholds = None 205 time_total = 0 206 behavior_dicts = [ 207 self.get_behavior_dictionary(episode_name) for episode_name in names 208 ] 209 210 if not all( 211 [ 212 set(d.values()) == set(behavior_dicts[0].values()) 213 for d in behavior_dicts[1:] 214 ] 215 ): 216 raise ValueError( 217 f"Episodes {episode_names} have different sets of behaviors!" 218 ) 219 behaviors = list(behavior_dicts[0].values()) 220 221 for episode_name, load_epoch, behavior_dict in zip( 222 names, load_epochs, behavior_dicts 223 ): 224 print(f"episode {episode_name}") 225 task, parameters, data_mode = self._make_task_prediction( 226 prediction_name=prediction_name, 227 load_episode=episode_name, 228 parameters_update=parameters_update, 229 load_epoch=load_epoch, 230 data_path=data_path, 231 mode=mode, 232 file_paths=file_paths, 233 task=task, 234 decision_thresholds=decision_thresholds, 235 annotation_type=annotation_type, 236 ) 237 # data_mode = "train" if mode == "all" else mode 238 time_start = time.time() 239 new_pred = task.predict( 240 data_mode, 241 raw_output=True, 242 apply_primary_function=True, 243 augment_n=augment_n, 244 embedding=embedding, 245 ) 246 indices = [ 247 behaviors.index(behavior_dict[i]) for i in range(new_pred.shape[1]) 248 ] 249 new_pred = new_pred[:, indices, :] 250 time_end = time.time() 251 time_total += time_end - time_start 252 if evaluate: 253 _, metrics = task.evaluate_prediction( 254 new_pred, data=data_mode, indices=indices 255 ) 256 if mode == "val": 257 self._update_episode_metrics(episode_name, metrics) 258 if prediction is None: 259 
prediction = new_pred 260 else: 261 prediction += new_pred 262 print("\n") 263 hours = int(time_total // 3600) 264 time_total -= hours * 3600 265 minutes = int(time_total // 60) 266 time_total -= minutes * 60 267 seconds = int(time_total) 268 inference_time = f"{hours}:{minutes:02}:{seconds:02}" 269 prediction /= len(names) 270 return ( 271 task, 272 parameters, 273 data_mode, 274 prediction, 275 inference_time, 276 behavior_dicts[0], 277 ) 278 279 def _make_task_prediction( 280 self, 281 prediction_name: str, 282 load_episode: str = None, 283 parameters_update: Dict = None, 284 load_epoch: int = None, 285 data_path: str = None, 286 annotation_path: str = None, 287 mode: str = "val", 288 file_paths: Set = None, 289 decision_thresholds: List = None, 290 task: TaskDispatcher = None, 291 annotation_type: str = "none", 292 ) -> Tuple[TaskDispatcher, Dict, str]: 293 """Make a `TaskDispatcher` object that will be used to generate a prediction.""" 294 if parameters_update is None: 295 parameters_update = {} 296 parameters_update_second = {} 297 if mode == "all" or data_path is not None or file_paths is not None: 298 parameters_update_second["training"] = { 299 "val_frac": 0, 300 "test_frac": 0, 301 "partition_method": "random", 302 "save_split": False, 303 "split_path": None, 304 } 305 mode = "train" 306 if decision_thresholds is not None: 307 if ( 308 len(decision_thresholds) 309 == self._episode(load_episode).get_num_classes() 310 ): 311 parameters_update_second["general"] = { 312 "threshold_value": decision_thresholds 313 } 314 else: 315 raise ValueError( 316 f"The length of the decision thresholds {decision_thresholds} " 317 f"must be equal to the length of the behaviors dictionary " 318 f"{self._episode(load_episode).get_behaviors_dict()}" 319 ) 320 data_param_update = {} 321 if data_path is not None: 322 data_param_update = {"data_path": data_path} 323 if annotation_path is None: 324 data_param_update["annotation_path"] = data_path 325 if annotation_path is not 
None: 326 data_param_update["annotation_path"] = annotation_path 327 if file_paths is not None: 328 data_param_update = {"data_path": None, "file_paths": file_paths} 329 parameters_update = self._update(parameters_update, {"data": data_param_update}) 330 if data_path is not None or file_paths is not None: 331 general_update = { 332 "annotation_type": annotation_type, 333 "only_load_annotated": False, 334 } 335 else: 336 general_update = {} 337 parameters_update = self._update(parameters_update, {"general": general_update}) 338 task, parameters = self._make_task( 339 episode_name=prediction_name, 340 load_episode=load_episode, 341 parameters_update=parameters_update, 342 parameters_update_second=parameters_update_second, 343 load_epoch=load_epoch, 344 purpose="prediction", 345 task=task, 346 ) 347 return task, parameters, mode 348 349 def _make_task_training( 350 self, 351 episode_name: str, 352 load_episode: str = None, 353 parameters_update: Dict = None, 354 load_epoch: int = None, 355 load_search: str = None, 356 load_parameters: list = None, 357 round_to_binary: list = None, 358 load_strict: bool = True, 359 continuing: bool = False, 360 task: TaskDispatcher = None, 361 mask_name: str = None, 362 throwaway: bool = False, 363 ) -> Tuple[TaskDispatcher, Dict, str]: 364 """Make a `TaskDispatcher` object that will be used to generate a prediction.""" 365 if parameters_update is None: 366 parameters_update = {} 367 if continuing: 368 purpose = "continuing" 369 else: 370 purpose = "training" 371 if mask_name is not None: 372 mask_name = os.path.join(self._mask_path(), f"{mask_name}.pickle") 373 parameters_update_second = {"data": {"real_lens": mask_name}} 374 if throwaway: 375 parameters_update = self._update( 376 parameters_update, {"training": {"normalize": False, "device": "cpu"}} 377 ) 378 return self._make_task( 379 episode_name, 380 load_episode, 381 parameters_update, 382 parameters_update_second, 383 load_epoch, 384 load_search, 385 load_parameters, 386 
round_to_binary, 387 purpose, 388 task, 389 load_strict=load_strict, 390 ) 391 392 def _make_parameters( 393 self, 394 episode_name: str, 395 load_episode: str = None, 396 parameters_update: Dict = None, 397 parameters_update_second: Dict = None, 398 load_epoch: int = None, 399 load_search: str = None, 400 load_parameters: list = None, 401 round_to_binary: list = None, 402 purpose: str = "train", 403 load_strict: bool = True, 404 ): 405 """Construct a parameters dictionary.""" 406 if parameters_update is None: 407 parameters_update = {} 408 pars_update = deepcopy(parameters_update) 409 if parameters_update_second is None: 410 parameters_update_second = {} 411 if ( 412 purpose == "prediction" 413 and "model" in pars_update.keys() 414 and pars_update["general"]["model_name"] != "motionbert" 415 ): 416 raise ValueError("Cannot change model parameters after training!") 417 if purpose in ["continuing", "prediction"] and load_episode is not None: 418 read_parameters = self._read_parameters() 419 parameters = self._episodes().load_parameters(load_episode) 420 parameters["metrics"] = self._update( 421 read_parameters["metrics"], parameters["metrics"] 422 ) 423 parameters["ssl"] = self._update( 424 read_parameters["ssl"], parameters.get("ssl", {}) 425 ) 426 else: 427 parameters = self._read_parameters() 428 if "model" in pars_update: 429 model_params = pars_update.pop("model") 430 else: 431 model_params = None 432 if "features" in pars_update: 433 feat_params = pars_update.pop("features") 434 else: 435 feat_params = None 436 if "augmentations" in pars_update: 437 aug_params = pars_update.pop("augmentations") 438 else: 439 aug_params = None 440 parameters = self._update(parameters, pars_update) 441 if pars_update.get("general", {}).get("model_name") is not None: 442 model_name = parameters["general"]["model_name"] 443 parameters["model"] = self._open_yaml( 444 os.path.join(self.project_path, "config", "model", f"{model_name}.yaml") 445 ) 446 if pars_update.get("general", 
{}).get("feature_extraction") is not None: 447 feat_name = parameters["general"]["feature_extraction"] 448 parameters["features"] = self._open_yaml( 449 os.path.join( 450 self.project_path, "config", "features", f"{feat_name}.yaml" 451 ) 452 ) 453 aug_name = options.extractor_to_transformer[ 454 parameters["general"]["feature_extraction"] 455 ] 456 parameters["augmentations"] = self._open_yaml( 457 os.path.join( 458 self.project_path, "config", "augmentations", f"{aug_name}.yaml" 459 ) 460 ) 461 if model_params is not None: 462 parameters["model"] = self._update(parameters["model"], model_params) 463 if feat_params is not None: 464 parameters["features"] = self._update(parameters["features"], feat_params) 465 if aug_params is not None: 466 parameters["augmentations"] = self._update( 467 parameters["augmentations"], aug_params 468 ) 469 if load_search is not None: 470 parameters = self._update_with_search( 471 parameters, load_search, load_parameters, round_to_binary 472 ) 473 parameters = self._fill( 474 parameters, 475 episode_name, 476 load_episode, 477 load_epoch=load_epoch, 478 load_strict=load_strict, 479 only_load_model=(purpose != "continuing"), 480 continuing=(purpose in ["prediction", "continuing"]), 481 enforce_split_parameters=(purpose == "prediction"), 482 ) 483 parameters = self._update(parameters, parameters_update_second) 484 return parameters 485 486 def _make_task( 487 self, 488 episode_name: str, 489 load_episode: str = None, 490 parameters_update: Dict = None, 491 parameters_update_second: Dict = None, 492 load_epoch: int = None, 493 load_search: str = None, 494 load_parameters: list = None, 495 round_to_binary: list = None, 496 purpose: str = "train", 497 task: TaskDispatcher = None, 498 load_strict: bool = True, 499 ) -> Tuple[TaskDispatcher, Union[CommentedMap, dict]]: 500 """Make a `TaskDispatcher` object. 501 502 The task parameters are read from the config files and then updated with the 503 parameters_update dictionary. 
The model can be either initialized from scratch or loaded from one of the 504 previous experiments. All parameters and results are saved in the meta files and can be accessed with the 505 list_episodes() function. The train/test/validation split is saved and loaded from a file whenever the 506 same split parameters are used. The pre-computed datasets are also saved and loaded whenever the same 507 data parameters are used. 508 509 Parameters 510 ---------- 511 episode_name : str 512 the name of the episode 513 load_episode : str, optional 514 the (previously run) episode name to load the model from 515 parameters_update : dict, optional 516 the dictionary used to update the parameters from the config 517 parameters_update_second : dict, optional 518 the dictionary used to update the parameters after the automatic fill-out 519 load_epoch : int, optional 520 the epoch to load (if load_episodes is not None); if not provided, the last epoch is used 521 load_search : str, optional 522 the hyperparameter search result to load 523 load_parameters : list, optional 524 a list of string names of the parameters to load from load_search (if not provided, all parameters 525 are loaded) 526 round_to_binary : list, optional 527 a list of string names of the loaded parameters that should be rounded to the nearest power of two 528 purpose : {"train", "continuing", "prediction"} 529 the purpose of the task object (`"train"` for training from scratch, `"continuing"` for continuing 530 the training of an interrupted episode, `"prediction"` for generating a prediction) 531 task : TaskDispatcher, optional 532 a pre-existing task; if provided, the method will update the task instead of creating a new one 533 (this might save time, mainly on dataset loading) 534 535 Returns 536 ------- 537 task : TaskDispatcher 538 the `TaskDispatcher` instance 539 parameters : dict 540 the parameters dictionary that describes the task 541 542 """ 543 parameters = self._make_parameters( 544 episode_name, 
545 load_episode, 546 parameters_update, 547 parameters_update_second, 548 load_epoch, 549 load_search, 550 load_parameters, 551 round_to_binary, 552 purpose, 553 load_strict=load_strict, 554 ) 555 if task is None: 556 task = TaskDispatcher(parameters) 557 else: 558 task.update_task(parameters) 559 self._save_stores(parameters) 560 return task, parameters 561 562 def get_decision_thresholds( 563 self, 564 episode_names: List, 565 metric_name: str = "f1", 566 parameters_update: Dict = None, 567 load_epochs: List = None, 568 remove_saved_features: bool = False, 569 ) -> Tuple[List, List, TaskDispatcher]: 570 """Compute optimal decision thresholds or load them if they have been computed before. 571 572 Parameters 573 ---------- 574 episode_names : List 575 a list of episode names 576 metric_name : {"f1", "segmental_f1", "semisegmental_f1", "f_beta", "segmental_f_beta"} 577 the metric to optimize 578 parameters_update : dict, optional 579 the parameter update dictionary 580 load_epochs : list, optional 581 a list of epochs to load (by default last are loaded) 582 remove_saved_features : bool, default False 583 if `True`, the dataset will be deleted after the computation 584 585 Returns 586 ------- 587 thresholds : list 588 a list of float decision threshold values 589 classes : list 590 the label names corresponding to the values 591 task : TaskDispatcher | None 592 the task used in computation 593 594 """ 595 parameters = self._make_parameters( 596 "_", 597 episode_names[0], 598 parameters_update, 599 {}, 600 load_epochs[0], 601 purpose="prediction", 602 ) 603 thresholds = self._thresholds().find_thresholds( 604 episode_names, 605 load_epochs, 606 metric_name, 607 metric_parameters=parameters["metrics"][metric_name], 608 ) 609 task = None 610 behaviors = list(self._episode(episode_names[0]).get_behaviors_dict().values()) 611 return thresholds, behaviors, task 612 613 def run_episode( 614 self, 615 episode_name: str, 616 load_episode: str = None, 617 parameters_update: 
Dict = None, 618 task: TaskDispatcher = None, 619 load_epoch: int = None, 620 load_search: str = None, 621 load_parameters: list = None, 622 round_to_binary: list = None, 623 load_strict: bool = True, 624 n_seeds: int = 1, 625 force: bool = False, 626 suppress_name_check: bool = False, 627 remove_saved_features: bool = False, 628 mask_name: str = None, 629 autostop_metric: str = None, 630 autostop_interval: int = 50, 631 autostop_threshold: float = 0.001, 632 loading_bar: bool = False, 633 trial: Tuple = None, 634 ) -> TaskDispatcher: 635 """Run an episode. 636 637 The task parameters are read from the config files and then updated with the 638 parameters_update dictionary. The model can be either initialized from scratch or loaded from one of the 639 previous experiments. All parameters and results are saved in the meta files and can be accessed with the 640 list_episodes() function. The train/test/validation split is saved and loaded from a file whenever the 641 same split parameters are used. The pre-computed datasets are also saved and loaded whenever the same 642 data parameters are used. 643 644 You can use the autostop parameters to finish training when the parameters are not improving. It will be 645 stopped if the average value of `autostop_metric` over the last `autostop_interval` epochs is smaller than 646 the average over the previous `autostop_interval` epochs + `autostop_threshold`. For example, if the 647 current epoch is 120 and `autostop_interval` is 50, the averages over epochs 70-120 and 20-70 will be compared. 
648 649 Parameters 650 ---------- 651 episode_name : str 652 the episode name 653 load_episode : str, optional 654 the (previously run) episode name to load the model from; if the episode has multiple runs, 655 the new episode will have the same number of runs, each starting with one of the pre-trained models 656 parameters_update : dict, optional 657 the dictionary used to update the parameters from the config files 658 task : TaskDispatcher, optional 659 a pre-existing `TaskDispatcher` object (if provided, the method will update it instead of creating 660 a new instance) 661 load_epoch : int, optional 662 the epoch to load (if load_episodes is not None); if not provided, the last epoch is used 663 load_search : str, optional 664 the hyperparameter search result to load 665 load_parameters : list, optional 666 a list of string names of the parameters to load from load_search (if not provided, all parameters 667 are loaded) 668 round_to_binary : list, optional 669 a list of string names of the loaded parameters that should be rounded to the nearest power of two 670 load_strict : bool, default True 671 if `False`, matching weights will be loaded from `load_episode` and differences in parameter name lists and 672 weight shapes will be ignored; otherwise mismatches will prompt a `RuntimeError` 673 n_seeds : int, default 1 674 the number of runs to perform; if `n_seeds > 1`, the episodes will be named `episode_name#run_index`, e.g. 675 `test_episode#0` and `test_episode#1` 676 force : bool, default False 677 if `True` and an episode with name `episode_name` already exists, it will be overwritten (use with caution!) 
678 suppress_name_check : bool, default False 679 if `True`, episode names with a double colon are allowed (please don't use this option unless you understand 680 why they are usually forbidden) 681 remove_saved_features : bool, default False 682 if `True`, the dataset will be deleted after training 683 mask_name : str, optional 684 the name of the real_lens to apply 685 autostop_metric : str, optional 686 the autostop metric (can be any one of the tracked metrics of `'loss'`) 687 autostop_interval : int, default 50 688 the number of epochs to average the autostop metric over 689 autostop_threshold : float, default 0.001 690 the autostop difference threshold 691 loading_bar : bool, default False 692 if `True`, a loading bar will be displayed 693 trial : tuple, optional 694 a tuple of (trial, metric) for hyperparameter search 695 696 Returns 697 ------- 698 TaskDispatcher 699 the `TaskDispatcher` object 700 701 """ 702 703 import gc 704 705 gc.collect() 706 if torch.cuda.is_available(): 707 torch.cuda.empty_cache() 708 709 if type(n_seeds) is not int or n_seeds < 1: 710 raise ValueError( 711 f"The n_seeds parameter has to be an integer larger than 0; got {n_seeds}" 712 ) 713 if n_seeds > 1 and mask_name is not None: 714 raise ValueError("Cannot apply a real_lens with n_seeds > 1") 715 self._check_episode_validity( 716 episode_name, allow_doublecolon=suppress_name_check, force=force 717 ) 718 load_runs = self._episodes().get_runs(load_episode) 719 if len(load_runs) > 1: 720 task = self.run_episodes( 721 episode_names=[ 722 f'{episode_name}#{run.split("#")[-1]}' for run in load_runs 723 ], 724 load_episodes=load_runs, 725 parameters_updates=[parameters_update for _ in load_runs], 726 load_epochs=[load_epoch for _ in load_runs], 727 load_searches=[load_search for _ in load_runs], 728 load_parameters=[load_parameters for _ in load_runs], 729 round_to_binary=[round_to_binary for _ in load_runs], 730 load_strict=[load_strict for _ in load_runs], 731 
suppress_name_check=True, 732 force=force, 733 remove_saved_features=False, 734 ) 735 if remove_saved_features: 736 self._remove_stores( 737 { 738 "general": task.general_parameters, 739 "data": task.data_parameters, 740 "features": task.feature_parameters, 741 } 742 ) 743 if n_seeds > 1: 744 warnings.warn( 745 f"The n_seeds parameter is disregarded since load_episode={load_episode} has multiple runs" 746 ) 747 elif n_seeds > 1: 748 749 self.run_episodes( 750 episode_names=[f"{episode_name}#{i}" for i in range(n_seeds)], 751 load_episodes=[load_episode for _ in range(n_seeds)], 752 parameters_updates=[parameters_update for _ in range(n_seeds)], 753 load_epochs=[load_epoch for _ in range(n_seeds)], 754 load_searches=[load_search for _ in range(n_seeds)], 755 load_parameters=[load_parameters for _ in range(n_seeds)], 756 round_to_binary=[round_to_binary for _ in range(n_seeds)], 757 load_strict=[load_strict for _ in range(n_seeds)], 758 suppress_name_check=True, 759 force=force, 760 remove_saved_features=remove_saved_features, 761 ) 762 else: 763 print(f"TRAINING {episode_name}") 764 try: 765 task, parameters = self._make_task_training( 766 episode_name, 767 load_episode, 768 parameters_update, 769 load_epoch, 770 load_search, 771 load_parameters, 772 round_to_binary, 773 continuing=False, 774 task=task, 775 mask_name=mask_name, 776 load_strict=load_strict, 777 ) 778 self._save_episode( 779 episode_name, 780 parameters, 781 task.behaviors_dict(), 782 norm_stats=task.get_normalization_stats(), 783 ) 784 time_start = time.time() 785 if trial is not None: 786 trial, metric = trial 787 else: 788 trial, metric = None, None 789 logs = task.train( 790 autostop_metric=autostop_metric, 791 autostop_interval=autostop_interval, 792 autostop_threshold=autostop_threshold, 793 loading_bar=loading_bar, 794 trial=trial, 795 optimized_metric=metric, 796 ) 797 time_end = time.time() 798 time_total = time_end - time_start 799 hours = int(time_total // 3600) 800 time_total -= hours * 
3600 801 minutes = int(time_total // 60) 802 time_total -= minutes * 60 803 seconds = int(time_total) 804 training_time = f"{hours}:{minutes:02}:{seconds:02}" 805 self._update_episode_results(episode_name, logs, training_time) 806 if remove_saved_features: 807 self._remove_stores(parameters) 808 print("\n") 809 return task 810 811 except Exception as e: 812 if isinstance(e, optuna.exceptions.TrialPruned): 813 raise e 814 else: 815 # if str(e) != f"The {episode_name} episode name is already in use!": 816 # self.remove_episode(episode_name) 817 raise RuntimeError(f"Episode {episode_name} could not run") 818 819 def run_episodes( 820 self, 821 episode_names: List, 822 load_episodes: List = None, 823 parameters_updates: List = None, 824 load_epochs: List = None, 825 load_searches: List = None, 826 load_parameters: List = None, 827 round_to_binary: List = None, 828 load_strict: List = None, 829 force: bool = False, 830 suppress_name_check: bool = False, 831 remove_saved_features: bool = False, 832 ) -> TaskDispatcher: 833 """Run multiple episodes in sequence (and re-use previously loaded information). 834 835 For each episode, the task parameters are read from the config files and then updated with the 836 parameter_update dictionary. The model can be either initialized from scratch or loaded from one of the 837 previous experiments. All parameters and results are saved in the meta files and can be accessed with the 838 list_episodes() function. The train/test/validation split is saved and loaded from a file whenever the 839 same split parameters are used. The pre-computed datasets are also saved and loaded whenever the same 840 data parameters are used. 
841 842 Parameters 843 ---------- 844 episode_names : list 845 a list of strings of episode names 846 load_episodes : list, optional 847 a list of strings of (previously run) episode names to load the model from; if the episode has multiple runs, 848 the new episode will have the same number of runs, each starting with one of the pre-trained models 849 parameters_updates : list, optional 850 a list of dictionaries used to update the parameters from the config 851 load_epochs : list, optional 852 a list of integers used to specify the epoch to load (if load_episodes is not None) 853 load_searches : list, optional 854 a list of strings of hyperparameter search results to load 855 load_parameters : list, optional 856 a list of lists of string names of the parameters to load from the searches 857 round_to_binary : list, optional 858 a list of string names of the loaded parameters that should be rounded to the nearest power of two 859 load_strict : list, optional 860 a list of boolean values specifying weight loading policy: if `False`, matching weights will be loaded from 861 the corresponding episode and differences in parameter name lists and 862 weight shapes will be ignored; otherwise mismatches will prompt a `RuntimeError` (by default `True` for 863 every episode) 864 force : bool, default False 865 if `True` and an episode name is already taken, it will be overwritten (use with caution!) 
866 suppress_name_check : bool, default False 867 if `True`, episode names with a double colon are allowed (please don't use this option unless you understand 868 why they are usually forbidden) 869 remove_saved_features : bool, default False 870 if `True`, the dataset will be deleted after training 871 872 Returns 873 ------- 874 TaskDispatcher 875 the task dispatcher object 876 877 """ 878 task = None 879 if load_searches is None: 880 load_searches = [None for _ in episode_names] 881 if load_episodes is None: 882 load_episodes = [None for _ in episode_names] 883 if parameters_updates is None: 884 parameters_updates = [None for _ in episode_names] 885 if load_parameters is None: 886 load_parameters = [None for _ in episode_names] 887 if load_epochs is None: 888 load_epochs = [None for _ in episode_names] 889 if load_strict is None: 890 load_strict = [True for _ in episode_names] 891 for ( 892 parameters_update, 893 episode_name, 894 load_episode, 895 load_epoch, 896 load_search, 897 load_parameters_list, 898 load_strict_value, 899 ) in zip( 900 parameters_updates, 901 episode_names, 902 load_episodes, 903 load_epochs, 904 load_searches, 905 load_parameters, 906 load_strict, 907 ): 908 task = self.run_episode( 909 episode_name, 910 load_episode, 911 parameters_update, 912 task, 913 load_epoch, 914 load_search, 915 load_parameters_list, 916 round_to_binary, 917 load_strict_value, 918 suppress_name_check=suppress_name_check, 919 force=force, 920 remove_saved_features=remove_saved_features, 921 ) 922 return task 923 924 def continue_episode( 925 self, 926 episode_name: str, 927 num_epochs: int = None, 928 task: TaskDispatcher = None, 929 n_seeds: int = 1, 930 remove_saved_features: bool = False, 931 device: str = "cuda", 932 num_cpus: int = None, 933 ) -> TaskDispatcher: 934 """Load an older episode and continue running from the latest checkpoint. 935 936 All parameters as well as the model and optimizer state dictionaries are loaded from the episode. 
937 938 Parameters 939 ---------- 940 episode_name : str 941 the name of the episode to continue 942 num_epochs : int, optional 943 the new number of epochs 944 task : TaskDispatcher, optional 945 a pre-existing task; if provided, the method will update the task instead of creating a new one 946 (this might save time, mainly on dataset loading) 947 n_seeds : int, default 1 948 the number of runs to perform; if `n_seeds > 1`, the episodes will be named `episode_name#run_index`, e.g. 949 `test_episode#0` and `test_episode#1` 950 remove_saved_features : bool, default False 951 if `True`, pre-computed features will be deleted after the run 952 device : str, default "cuda" 953 the torch device to use 954 num_cpus : int, optional 955 the number of CPUs to use for data loading; if `None`, the number of available CPUs will be used 956 957 Returns 958 ------- 959 TaskDispatcher 960 the task dispatcher 961 962 """ 963 runs = self._episodes().get_runs(episode_name) 964 for run in runs: 965 print(f"TRAINING {run}") 966 if num_epochs is None and not self._episode(run).unfinished(): 967 continue 968 parameters_update = { 969 "training": { 970 "num_epochs": num_epochs, 971 "device": device, 972 }, 973 "general": {"num_cpus": num_cpus}, 974 } 975 task, parameters = self._make_task_training( 976 run, 977 load_episode=run, 978 parameters_update=parameters_update, 979 continuing=True, 980 task=task, 981 ) 982 time_start = time.time() 983 logs = task.train() 984 time_end = time.time() 985 old_time = self._training_time(run) 986 if not np.isnan(old_time): 987 time_end += old_time 988 time_total = time_end - time_start 989 hours = int(time_total // 3600) 990 time_total -= hours * 3600 991 minutes = int(time_total // 60) 992 time_total -= minutes * 60 993 seconds = int(time_total) 994 training_time = f"{hours}:{minutes:02}:{seconds:02}" 995 else: 996 training_time = np.nan 997 self._save_episode( 998 run, 999 parameters, 1000 task.behaviors_dict(), 1001 suppress_validation=True, 1002 
training_time=training_time, 1003 norm_stats=task.get_normalization_stats(), 1004 ) 1005 self._update_episode_results(run, logs) 1006 print("\n") 1007 if len(runs) < n_seeds: 1008 for i in range(len(runs), n_seeds): 1009 self.run_episode( 1010 f"{episode_name}#{i}", 1011 parameters_update=self._episodes().load_parameters(runs[0]), 1012 task=task, 1013 suppress_name_check=True, 1014 ) 1015 if remove_saved_features: 1016 self._remove_stores(parameters) 1017 return task 1018 1019 def run_default_hyperparameter_search( 1020 self, 1021 search_name: str, 1022 model_name: str, 1023 metric: str = "f1", 1024 best_n: int = 3, 1025 direction: str = "maximize", 1026 load_episode: str = None, 1027 load_epoch: int = None, 1028 load_strict: bool = True, 1029 prune: bool = True, 1030 force: bool = False, 1031 remove_saved_features: bool = False, 1032 overlap: float = 0, 1033 num_epochs: int = 50, 1034 test_frac: float = None, 1035 n_trials=150, 1036 batch_size=32, 1037 ): 1038 """Run an optuna hyperparameter search with default parameters for a model. 1039 1040 For the vast majority of cases, optimizing the default parameters should be enough. 1041 Check out `dlc2action.options.model_hyperparameters` for the lists of parameters. 1042 There are also options to set overlap, test fraction and number of epochs parameters for the search without 1043 modifying the project config files. However, if you want something more complex, look into 1044 `Project.run_hyperparameter_search`. 1045 1046 The task parameters are read from the config files and updated with the parameters_update dictionary. 1047 The model can be either initialized from scratch or loaded from a previously run episode. 1048 For each trial, the objective metric is averaged over a few best epochs. 
1049 1050 Parameters 1051 ---------- 1052 search_name : str 1053 the name of the search to store it in the meta files and load in run_episode 1054 model_name : str 1055 the name 1056 metric : str 1057 the metric to maximize/minimize (see direction); if the metric has an `"average"` parameter and it is set to 1058 `"none"` in the config files, it will be reset to `"macro"` for the search 1059 best_n : int, default 1 1060 the number of epochs to average the metric; if 0, the last value is taken 1061 direction : {'maximize', 'minimize'} 1062 optimization direction 1063 load_episode : str, optional 1064 the name of the episode to load the model from 1065 load_epoch : int, optional 1066 the epoch to load the model from (if not provided, the last checkpoint is used) 1067 load_strict : bool, default True 1068 if `True`, the model will be loaded only if the parameters match exactly 1069 prune : bool, default False 1070 if `True`, experiments where the optimized metric is improving too slowly will be terminated 1071 (with optuna HyperBand pruner) 1072 force : bool, default False 1073 if `True`, existing searches with the same name will be overwritten 1074 remove_saved_features : bool, default False 1075 if `True`, pre-computed features will be deleted after each run (if the data parameters change) 1076 overlap : float, default 0 1077 the overlap to use for the search 1078 num_epochs : int, default 50 1079 the number of epochs to use for the search 1080 test_frac : float, optional 1081 the test fraction to use for the search 1082 n_trials : int, default 150 1083 the number of trials to run 1084 batch_size : int, default 32 1085 the batch size to use for the search 1086 1087 Returns 1088 ------- 1089 best_parameters : dict 1090 a dictionary of best parameters 1091 1092 """ 1093 if model_name not in options.model_hyperparameters: 1094 raise ValueError( 1095 f"There is no default search space for {model_name}! 
Please choose from {options.model_hyperparameters.keys()} or try project.run_hyperparameter_search()" 1096 ) 1097 pars = { 1098 "general": {"overlap": overlap, "model_name": model_name}, 1099 "training": {"num_epochs": num_epochs, "batch_size": batch_size}, 1100 } 1101 if test_frac is not None: 1102 pars["training"]["test_frac"] = test_frac 1103 if not metric.split("_")[-1].isnumeric(): 1104 project_pars = self._read_parameters() 1105 if project_pars["metrics"][metric].get("average") == "none": 1106 pars["metrics"] = {metric: {"average": "macro"}} 1107 return self.run_hyperparameter_search( 1108 search_name=search_name, 1109 search_space=options.model_hyperparameters[model_name], 1110 metric=metric, 1111 n_trials=n_trials, 1112 best_n=best_n, 1113 parameters_update=pars, 1114 direction=direction, 1115 load_episode=load_episode, 1116 load_epoch=load_epoch, 1117 load_strict=load_strict, 1118 prune=prune, 1119 force=force, 1120 remove_saved_features=remove_saved_features, 1121 ) 1122 1123 def run_hyperparameter_search( 1124 self, 1125 search_name: str, 1126 search_space: Dict, 1127 metric: str = "f1", 1128 n_trials: int = 20, 1129 best_n: int = 1, 1130 parameters_update: Dict = None, 1131 direction: str = "maximize", 1132 load_episode: str = None, 1133 load_epoch: int = None, 1134 load_strict: bool = True, 1135 prune: bool = False, 1136 force: bool = False, 1137 remove_saved_features: bool = False, 1138 make_plots: bool = True, 1139 ) -> Dict: 1140 """Run an optuna hyperparameter search. 1141 1142 For a simpler function that fits most use cases, check out `Project.run_default_hyperparameter_search()`. 1143 1144 To use a default search space with this method, import `dlc2action.options.model_hyperparameters`. It is 1145 a dictionary where keys are model names and values are default search spaces. 1146 1147 The task parameters are read from the config files and updated with the parameters_update dictionary. 
1148 The model can be either initialized from scratch or loaded from a previously run episode. 1149 For each trial, the objective metric is averaged over a few best epochs. 1150 1151 Parameters 1152 ---------- 1153 search_name : str 1154 the name of the search to store it in the meta files and load in run_episode 1155 search_space : dict 1156 a dictionary representing the search space; of this general structure: 1157 {'group/param_name': ('float/int/float_log/int_log', start, end), 1158 'group/param_name': ('categorical', [choices])}, e.g. 1159 {'data/overlap': ('int', 5, 100), 'training/lr': ('float_log', 1e-4, 1e-2), 1160 'data/feature_extraction': ('categorical', ['kinematic', 'bones'])}; 1161 metric : str, default f1 1162 the metric to maximize/minimize (see direction) 1163 n_trials : int, default 20 1164 the number of optimization trials to run 1165 best_n : int, default 1 1166 the number of epochs to average the metric; if 0, the last value is taken 1167 parameters_update : dict, optional 1168 the parameters update dictionary 1169 direction : {'maximize', 'minimize'} 1170 optimization direction 1171 load_episode : str, optional 1172 the name of the episode to load the model from 1173 load_epoch : int, optional 1174 the epoch to load the model from (if not provided, the last checkpoint is used) 1175 load_strict : bool, default True 1176 if `True`, the model will be loaded only if the parameters match exactly 1177 prune : bool, default False 1178 if `True`, experiments where the optimized metric is improving too slowly will be terminated 1179 (with optuna HyperBand pruner) 1180 force : bool, default False 1181 if `True`, existing searches with the same name will be overwritten 1182 remove_saved_features : bool, default False 1183 if `True`, pre-computed features will be deleted after each run (if the data parameters change) 1184 1185 Returns 1186 ------- 1187 dict 1188 a dictionary of best parameters 1189 1190 """ 1191 self._check_search_validity(search_name, 
force=force) 1192 print(f"SEARCH {search_name}") 1193 self.remove_episode(f"_{search_name}") 1194 if parameters_update is None: 1195 parameters_update = {} 1196 parameters_update = self._update( 1197 parameters_update, {"general": {"metric_functions": {metric}}} 1198 ) 1199 parameters = self._make_parameters( 1200 f"_{search_name}", 1201 load_episode, 1202 parameters_update, 1203 parameters_update_second={"training": {"model_save_path": None}}, 1204 load_epoch=load_epoch, 1205 load_strict=load_strict, 1206 ) 1207 task = None 1208 1209 if prune: 1210 pruner = optuna.pruners.HyperbandPruner() 1211 else: 1212 pruner = optuna.pruners.NopPruner() 1213 study = optuna.create_study(direction=direction, pruner=pruner) 1214 runner = _Runner( 1215 search_space=search_space, 1216 load_episode=load_episode, 1217 load_epoch=load_epoch, 1218 metric=metric, 1219 average=best_n, 1220 task=task, 1221 remove_saved_features=remove_saved_features, 1222 project=self, 1223 search_name=search_name, 1224 ) 1225 study.optimize(lambda trial: runner.run(trial, parameters), n_trials=n_trials) 1226 if make_plots: 1227 search_path = self._search_path(search_name) 1228 os.mkdir(search_path) 1229 fig = optuna.visualization.plot_contour(study) 1230 plotly.offline.plot( 1231 fig, filename=os.path.join(search_path, f"{search_name}_contour.html") 1232 ) 1233 fig = optuna.visualization.plot_param_importances(study) 1234 plotly.offline.plot( 1235 fig, 1236 filename=os.path.join(search_path, f"{search_name}_importances.html"), 1237 ) 1238 best_params = study.best_params 1239 best_value = study.best_value 1240 if best_value == 0 or best_value == float("inf"): 1241 raise ValueError( 1242 f"Best metric value is {best_value}, check your partition method and make sure that all behaviors are present in the validation set!" 
1243 ) 1244 self._save_search( 1245 search_name, 1246 parameters, 1247 n_trials, 1248 best_params, 1249 best_value, 1250 metric, 1251 search_space, 1252 ) 1253 self.remove_episode(f"_{search_name}") 1254 runner.clean() 1255 print(f"best parameters: {best_params}") 1256 print("\n") 1257 return best_params 1258 1259 def run_prediction( 1260 self, 1261 prediction_name: str, 1262 episode_names: List, 1263 load_epochs: List = None, 1264 parameters_update: Dict = None, 1265 augment_n: int = 10, 1266 data_path: str = None, 1267 mode: str = "all", 1268 file_paths: Set = None, 1269 remove_saved_features: bool = False, 1270 frame_number_map_file: str = None, 1271 force: bool = False, 1272 embedding: bool = False, 1273 ) -> None: 1274 """Load models from previously run episodes to generate a prediction. 1275 1276 The probabilities predicted by the models are averaged. 1277 Unless `submission` is `True`, the prediction results are saved as a pickled dictionary in the project_name/results/predictions folder 1278 under the {episode_name}_{load_epoch}.pickle name. The file is a nested dictionary where the first-level 1279 keys are the video ids, the second-level keys are the clip ids (like individual names) and the values 1280 are the prediction arrays. 
1281 1282 Parameters 1283 ---------- 1284 prediction_name : str 1285 the name of the prediction 1286 episode_names : list 1287 a list of string episode names to load the models from 1288 load_epochs : list or int, optional 1289 a list of integer epoch indices to load the model from; if None, the last ones are used, if int the same epoch is used for all episodes 1290 parameters_update : dict, optional 1291 a dictionary of parameter updates 1292 augment_n : int, default 10 1293 the number of augmentations to average over 1294 data_path : str, optional 1295 the data path to run the prediction for 1296 mode : {'all', 'test', 'val', 'train'} 1297 the subset of the data to make the prediction for (forced to 'all' if data_path is not None) 1298 file_paths : set, optional 1299 a set of string file paths (data with all prefixes + feature files, in any order) to run the prediction 1300 for 1301 remove_saved_features : bool, default False 1302 if `True`, pre-computed features will be deleted 1303 submission : bool, default False 1304 if `True`, a MABe-22 style submission file is generated 1305 frame_number_map_file : str, optional 1306 path to the frame number map file 1307 force : bool, default False 1308 if `True`, existing prediction with this name will be overwritten 1309 embedding : bool, default False 1310 if `True`, the prediction is made for the embedding task 1311 1312 """ 1313 self._check_prediction_validity(prediction_name, force=force) 1314 print(f"PREDICTION {prediction_name}") 1315 task, parameters, mode, prediction, inference_time, behavior_dict = ( 1316 self._make_prediction( 1317 prediction_name, 1318 episode_names, 1319 load_epochs, 1320 parameters_update, 1321 data_path, 1322 file_paths, 1323 mode, 1324 augment_n, 1325 evaluate=False, 1326 embedding=embedding, 1327 ) 1328 ) 1329 predicted = task.dataset(mode).generate_full_length_prediction(prediction) 1330 1331 if remove_saved_features: 1332 self._remove_stores(parameters) 1333 1334 self._save_prediction( 
1335 prediction_name, 1336 predicted, 1337 parameters, 1338 task, 1339 mode, 1340 embedding, 1341 inference_time, 1342 behavior_dict, 1343 ) 1344 print("\n") 1345 1346 def evaluate_prediction( 1347 self, 1348 prediction_name: str, 1349 parameters_update: Dict = None, 1350 data_path: str = None, 1351 annotation_path: str = None, 1352 file_paths: Set = None, 1353 mode: str = None, 1354 remove_saved_features: bool = False, 1355 annotation_type: str = "none", 1356 num_classes: int = None, # Set when using data_path 1357 ) -> Tuple[float, dict]: 1358 """Make predictions and evaluate them 1359 inputs: 1360 prediction_name (str): the name of the prediction 1361 parameters_update (dict): a dictionary of parameter updates 1362 data_path (str): the data path to run the prediction for 1363 annotation_path (str): the annotation path to run the prediction for 1364 file_paths (set): a set of string file paths (data with all prefixes + feature files, in any order) to run the prediction for 1365 mode (str): the subset of the data to make the prediction for (forced to 'all' if data_path is not None) 1366 remove_saved_features (bool): if `True`, pre-computed features will be deleted 1367 annotation_type (str): the type of annotation to use for evaluation 1368 num_classes (int): the number of classes in the dataset, must be set with data_path 1369 outputs: 1370 results (dict): a dictionary of average values of metric functions 1371 """ 1372 1373 prediction_path = os.path.join( 1374 self.project_path, "results", "predictions", f"{prediction_name}" 1375 ) 1376 prediction_dict = {} 1377 for prediction_file_path in [ 1378 os.path.join(prediction_path, i) for i in os.listdir(prediction_path) 1379 ]: 1380 with open(os.path.join(prediction_file_path), "rb") as f: 1381 prediction = pickle.load(f) 1382 video_id = os.path.basename(prediction_file_path).split( 1383 "_" + prediction_name 1384 )[0] 1385 prediction_dict[video_id] = prediction 1386 if parameters_update is None: 1387 
parameters_update = {} 1388 parameters_update = self._update( 1389 self._predictions().load_parameters(prediction_name), parameters_update 1390 ) 1391 parameters_update.pop("model") 1392 if not data_path is None: 1393 assert ( 1394 not num_classes is None 1395 ), "num_classes must be provided if data_path is provided" 1396 parameters_update["general"]["num_classes"] = num_classes + int( 1397 parameters_update["general"]["exclusive"] 1398 ) 1399 task, parameters, mode = self._make_task_prediction( 1400 "_", 1401 load_episode=None, 1402 parameters_update=parameters_update, 1403 data_path=data_path, 1404 annotation_path=annotation_path, 1405 file_paths=file_paths, 1406 mode=mode, 1407 annotation_type=annotation_type, 1408 ) 1409 results = task.evaluate_prediction(prediction_dict, data=mode) 1410 if remove_saved_features: 1411 self._remove_stores(parameters) 1412 results = Project._reformat_results( 1413 results[1], 1414 task.behaviors_dict(), 1415 exclusive=task.general_parameters["exclusive"], 1416 ) 1417 return results 1418 1419 def evaluate( 1420 self, 1421 episode_names: List, 1422 load_epochs: List = None, 1423 augment_n: int = 0, 1424 data_path: str = None, 1425 file_paths: Set = None, 1426 mode: str = None, 1427 parameters_update: Dict = None, 1428 multiple_episode_policy: str = "average", 1429 remove_saved_features: bool = False, 1430 skip_updating_meta: bool = True, 1431 annotation_type: str = "none", 1432 ) -> Dict: 1433 """Load one or several models from previously run episodes to make an evaluation. 1434 1435 By default it will run on the test (or validation, if there is no test) subset of the project dataset. 
1436 1437 Parameters 1438 ---------- 1439 episode_names : list 1440 a list of string episode names to load the models from 1441 load_epochs : list, optional 1442 a list of integer epoch indices to load the model from; if None, the last ones are used 1443 augment_n : int, default 0 1444 the number of augmentations to average over 1445 data_path : str, optional 1446 the data path to run the prediction for 1447 file_paths : set, optional 1448 a set of files to run the prediction for 1449 mode : {'test', 'val', 'train', 'all'} 1450 the subset of the data to make the prediction for (forced to 'all' if data_path is not None; 1451 by default 'test' if test subset is not empty and 'val' otherwise) 1452 parameters_update : dict, optional 1453 a dictionary with parameter updates (cannot change model parameters) 1454 multiple_episode_policy : {'average', 'statistics'} 1455 the policy to use when multiple episodes are provided 1456 remove_saved_features : bool, default False 1457 if `True`, the dataset will be deleted 1458 skip_updating_meta : bool, default True 1459 if `True`, the meta file will not be updated with the computed metrics 1460 1461 Returns 1462 ------- 1463 metric : dict 1464 a dictionary of average values of metric functions 1465 1466 """ 1467 names = [] 1468 for episode_name in episode_names: 1469 names += self._episodes().get_runs(episode_name) 1470 if len(set(episode_names)) == 1: 1471 print(f"EVALUATION {episode_names[0]}") 1472 else: 1473 print(f"EVALUATION {episode_names}") 1474 if len(names) > 1: 1475 evaluate = True 1476 else: 1477 evaluate = False 1478 if multiple_episode_policy == "average": 1479 task, parameters, mode, prediction, inference_time, behavior_dict = ( 1480 self._make_prediction( 1481 "_", 1482 episode_names, 1483 load_epochs, 1484 parameters_update, 1485 mode=mode, 1486 data_path=data_path, 1487 file_paths=file_paths, 1488 augment_n=augment_n, 1489 evaluate=evaluate, 1490 annotation_type=annotation_type, 1491 ) 1492 ) 1493 
print("EVALUATE PREDICTION:") 1494 indices = [ 1495 list(behavior_dict.keys()).index(i) for i in range(len(behavior_dict)) 1496 ] 1497 _, results = task.evaluate_prediction( 1498 prediction, data=mode, indices=indices 1499 ) 1500 if len(names) == 1 and mode == "val" and not skip_updating_meta: 1501 self._update_episode_metrics(names[0], results) 1502 results = Project._reformat_results( 1503 results, 1504 behavior_dict, 1505 exclusive=task.general_parameters["exclusive"], 1506 ) 1507 1508 elif multiple_episode_policy == "statistics": 1509 values = defaultdict(lambda: []) 1510 task = None 1511 for name in names: 1512 ( 1513 task, 1514 parameters, 1515 mode, 1516 prediction, 1517 inference_time, 1518 behavior_dict, 1519 ) = self._make_prediction( 1520 "_", 1521 [name], 1522 load_epochs, 1523 parameters_update, 1524 mode=mode, 1525 data_path=data_path, 1526 file_paths=file_paths, 1527 augment_n=augment_n, 1528 evaluate=evaluate, 1529 task=task, 1530 ) 1531 _, metrics = task.evaluate_prediction( 1532 prediction, data=mode, indices=list(behavior_dict.keys()) 1533 ) 1534 for name, value in metrics.items(): 1535 values[name].append(value) 1536 if mode == "val" and not skip_updating_meta: 1537 self._update_episode_metrics(name, metrics) 1538 results = defaultdict(lambda: {}) 1539 mean_string = "" 1540 std_string = "" 1541 for key, value_list in values.items(): 1542 results[key]["mean"] = np.mean(value_list) 1543 results[key]["std"] = np.std(value_list) 1544 results[key]["all"] = value_list 1545 mean_string += f"{key} {np.mean(value_list):.3f}, " 1546 std_string += f"{key} {np.std(value_list):.3f}, " 1547 print("MEAN:") 1548 print(mean_string) 1549 print("STD:") 1550 print(std_string) 1551 else: 1552 raise ValueError( 1553 f"The {multiple_episode_policy} multiple episode policy is not recognized; please choose " 1554 f"from ['average', 'statistics']" 1555 ) 1556 if len(names) > 0 and remove_saved_features: 1557 self._remove_stores(parameters) 1558 print(f"Inference time: 
{inference_time}") 1559 print("\n") 1560 return results 1561 1562 def run_suggestion( 1563 self, 1564 suggestions_name: str, 1565 error_episode: str = None, 1566 error_load_epoch: int = None, 1567 error_class: str = None, 1568 suggestions_prediction: str = None, 1569 suggestion_episodes: List = [None], 1570 suggestion_load_epoch: int = None, 1571 suggestion_classes: List = None, 1572 error_threshold: float = 0.5, 1573 error_threshold_diff: float = 0.1, 1574 error_hysteresis: bool = False, 1575 suggestion_threshold: Union[float, List] = 0.5, 1576 suggestion_threshold_diff: Union[float, List] = 0.1, 1577 suggestion_hysteresis: Union[bool, List] = True, 1578 min_frames_suggestion: int = 10, 1579 min_frames_al: int = 30, 1580 visibility_min_score: float = 0, 1581 visibility_min_frac: float = 0.7, 1582 augment_n: int = 0, 1583 exclude_classes: List = None, 1584 exclude_threshold: Union[float, List] = 0.6, 1585 exclude_threshold_diff: Union[float, List] = 0.1, 1586 exclude_hysteresis: Union[bool, List] = False, 1587 include_classes: List = None, 1588 include_threshold: Union[float, List] = 0.4, 1589 include_threshold_diff: Union[float, List] = 0.1, 1590 include_hysteresis: Union[bool, List] = False, 1591 data_path: str = None, 1592 file_paths: Set = None, 1593 parameters_update: Dict = None, 1594 mode: str = "all", 1595 force: bool = False, 1596 remove_saved_features: bool = False, 1597 cut_annotated: bool = False, 1598 background_threshold: float = None, 1599 ) -> None: 1600 """Create active learning and suggestion files. 1601 1602 Generate predictions with the error and suggestion model and use them to create 1603 suggestion files for the labeling interface. Those files will render as suggested labels 1604 at intervals with high pose estimation quality. Quality here is defined by probability of error 1605 (predicted by the error model) and visibility parameters. 
1606 1607 If `error_episode` or `exclude_classes` is not `None`, 1608 an active learning file will be created as well (with frames with high predicted probability of classes 1609 from `exclude_classes` and/or errors excluded from the active learning intervals). 1610 1611 In all three steps (predicting errors, suggesting labels and excluding them from active learning intervals) 1612 you can apply one of three methods. 1613 1614 - **Simple threshold** 1615 1616 Set the `hysteresis` parameter (e.g. `error_hysteresis`) to `False` and the `threshold` 1617 parameter to $\alpha$. 1618 In this case if the probability of a label is predicted to be higher than $\alpha$ the frame will 1619 be considered labeled. 1620 1621 - **Hysteresis threshold** 1622 1623 Set the `hysteresis` parameter (e.g. `error_hysteresis`) to `True`, the `threshold` 1624 parameter to $\alpha$ and the `threshold_diff` parameter to $\beta$. 1625 Now intervals will be marked with a label if the probability of that label for all frames is higher 1626 than $\alpha - \beta$ and at least for one frame in that interval it is higher than $\alpha$. 1627 1628 - **Max hysteresis threshold** 1629 1630 Set the `hysteresis` parameter (e.g. `error_hysteresis`) to `True`, the `threshold` 1631 parameter to $\alpha$ and the `threshold_diff` parameter to `None`. 1632 With this combination intervals are marked with a label if that label is more likely than any other 1633 for all frames in this interval and at for at least one of those frames its probability is higher than 1634 $\alpha$. 
1635 1636 Parameters 1637 ---------- 1638 suggestions_name : str 1639 the name of the suggestions 1640 error_episode : str, optional 1641 the name of the episode where the error model should be loaded from 1642 error_load_epoch : int, optional 1643 the epoch the error model should be loaded from 1644 error_class : str, optional 1645 the name of the error class (in `error_episode`) 1646 suggestions_prediction : str, optional 1647 the name of the predictions that should be used for the suggestion model 1648 suggestion_episodes : list, optional 1649 the names of the episodes where the suggestion models should be loaded from 1650 suggestion_load_epoch : int, optional 1651 the epoch the suggestion model should be loaded from 1652 suggestion_classes : list, optional 1653 a list of string names of the classes that should be suggested (in `suggestion_episode`) 1654 error_threshold : float, default 0.5 1655 the hard threshold for error prediction 1656 error_threshold_diff : float, default 0.1 1657 the difference between soft and hard thresholds for error prediction (in case hysteresis is used) 1658 error_hysteresis : bool, default False 1659 if True, hysteresis is used for error prediction 1660 suggestion_threshold : float | list, default 0.5 1661 the hard threshold for class prediction (use a list to set different rules for different classes) 1662 suggestion_threshold_diff : float | list, default 0.1 1663 the difference between soft and hard thresholds for class prediction (in case hysteresis is used; 1664 use a list to set different rules for different classes) 1665 suggestion_hysteresis : bool | list, default True 1666 if True, hysteresis is used for class prediction (use a list to set different rules for different classes) 1667 min_frames_suggestion : int, default 10 1668 only actions longer than this number of frames will be suggested 1669 min_frames_al : int, default 30 1670 only active learning intervals longer than this number of frames will be suggested 1671 
visibility_min_score : float, default 0 1672 the minimum visibility score for visibility filtering 1673 visibility_min_frac : float, default 0.7 1674 the minimum fraction of visible frames for visibility filtering 1675 augment_n : int, default 10 1676 the number of augmentations to average the predictions over 1677 exclude_classes : list, optional 1678 a list of string names of classes that should be excluded from the active learning intervals 1679 exclude_threshold : float | list, default 0.6 1680 the hard threshold for excluded class prediction (use a list to set different rules for different classes) 1681 exclude_threshold_diff : float | list, default 0.1 1682 the difference between soft and hard thresholds for excluded class prediction (in case hysteresis is used) 1683 exclude_hysteresis : bool | list, default False 1684 if True, hysteresis is used for excluded class prediction (use a list to set different rules for different classes) 1685 include_classes : list, optional 1686 a list of string names of classes that should be included into the active learning intervals 1687 include_threshold : float | list, default 0.6 1688 the hard threshold for included class prediction (use a list to set different rules for different classes) 1689 include_threshold_diff : float | list, default 0.1 1690 the difference between soft and hard thresholds for included class prediction (in case hysteresis is used) 1691 include_hysteresis : bool | list, default False 1692 if True, hysteresis is used for included class prediction (use a list to set different rules for different classes) 1693 data_path : str, optional 1694 the data path to run the prediction for 1695 file_paths : set, optional 1696 a set of string file paths (data with all prefixes + feature files, in any order) to run the prediction 1697 for 1698 parameters_update : dict, optional 1699 the parameters update dictionary 1700 mode : {'all', 'test', 'val', 'train'} 1701 the subset of the data to make the prediction for 
(forced to 'all' if data_path is not None) 1702 force : bool, default False 1703 if `True` and an episode with name `episode_name` already exists, it will be overwritten (use with caution!) 1704 remove_saved_features : bool, default False 1705 if `True`, the dataset will be deleted. 1706 cut_annotated : bool, default False 1707 if `True`, annotated frames will be cut from the suggestions 1708 background_threshold : float, default 0.5 1709 the threshold for background prediction 1710 1711 """ 1712 self._check_suggestions_validity(suggestions_name, force=force) 1713 if any([x is None for x in suggestion_episodes]): 1714 suggestion_episodes = None 1715 if error_episode is None and ( 1716 suggestion_episodes is None and suggestions_prediction is None 1717 ): 1718 raise ValueError( 1719 "Both error_episode and suggestion_episode parameters cannot be None at the same time" 1720 ) 1721 print(f"SUGGESTION {suggestions_name}") 1722 task = None 1723 if suggestion_classes is None: 1724 suggestion_classes = [] 1725 if exclude_classes is None: 1726 exclude_classes = [] 1727 if include_classes is None: 1728 include_classes = [] 1729 if isinstance(suggestion_threshold, list): 1730 if len(suggestion_threshold) != len(suggestion_classes): 1731 raise ValueError( 1732 "The suggestion_threshold parameter has to be either a float value or a list of " 1733 f"float values of the same length as suggestion_classes (got a list of length " 1734 f"{len(suggestion_threshold)} for {len(suggestion_classes)} classes)" 1735 ) 1736 else: 1737 suggestion_threshold = [suggestion_threshold for _ in suggestion_classes] 1738 if isinstance(suggestion_threshold_diff, list): 1739 if len(suggestion_threshold_diff) != len(suggestion_classes): 1740 raise ValueError( 1741 "The suggestion_threshold_diff parameter has to be either a float value or a list of " 1742 f"float values of the same length as suggestion_classes (got a list of length " 1743 f"{len(suggestion_threshold)} for {len(suggestion_classes)} 
classes)" 1744 ) 1745 else: 1746 suggestion_threshold_diff = [ 1747 suggestion_threshold_diff for _ in suggestion_classes 1748 ] 1749 if isinstance(suggestion_hysteresis, list): 1750 if len(suggestion_hysteresis) != len(suggestion_classes): 1751 raise ValueError( 1752 "The suggestion_threshold_diff parameter has to be either a float value or a list of " 1753 f"float values of the same length as suggestion_classes (got a list of length " 1754 f"{len(suggestion_hysteresis)} for {len(suggestion_classes)} classes)" 1755 ) 1756 else: 1757 suggestion_hysteresis = [suggestion_hysteresis for _ in suggestion_classes] 1758 if isinstance(exclude_threshold, list): 1759 if len(exclude_threshold) != len(exclude_classes): 1760 raise ValueError( 1761 "The exclude_threshold parameter has to be either a float value or a list of " 1762 f"float values of the same length as exclude_classes (got a list of length " 1763 f"{len(exclude_threshold)} for {len(exclude_classes)} classes)" 1764 ) 1765 else: 1766 exclude_threshold = [exclude_threshold for _ in exclude_classes] 1767 if isinstance(exclude_threshold_diff, list): 1768 if len(exclude_threshold_diff) != len(exclude_classes): 1769 raise ValueError( 1770 "The exclude_threshold_diff parameter has to be either a float value or a list of " 1771 f"float values of the same length as exclude_classes (got a list of length " 1772 f"{len(exclude_threshold_diff)} for {len(exclude_classes)} classes)" 1773 ) 1774 else: 1775 exclude_threshold_diff = [exclude_threshold_diff for _ in exclude_classes] 1776 if isinstance(exclude_hysteresis, list): 1777 if len(exclude_hysteresis) != len(exclude_classes): 1778 raise ValueError( 1779 "The suggestion_threshold_diff parameter has to be either a float value or a list of " 1780 f"float values of the same length as suggestion_classes (got a list of length " 1781 f"{len(exclude_hysteresis)} for {len(exclude_classes)} classes)" 1782 ) 1783 else: 1784 exclude_hysteresis = [exclude_hysteresis for _ in 
exclude_classes] 1785 if isinstance(include_threshold, list): 1786 if len(include_threshold) != len(include_classes): 1787 raise ValueError( 1788 "The exclude_threshold parameter has to be either a float value or a list of " 1789 f"float values of the same length as exclude_classes (got a list of length " 1790 f"{len(include_threshold)} for {len(include_classes)} classes)" 1791 ) 1792 else: 1793 include_threshold = [include_threshold for _ in include_classes] 1794 if isinstance(include_threshold_diff, list): 1795 if len(include_threshold_diff) != len(include_classes): 1796 raise ValueError( 1797 "The exclude_threshold_diff parameter has to be either a float value or a list of " 1798 f"float values of the same length as exclude_classes (got a list of length " 1799 f"{len(include_threshold_diff)} for {len(include_classes)} classes)" 1800 ) 1801 else: 1802 include_threshold_diff = [include_threshold_diff for _ in include_classes] 1803 if isinstance(include_hysteresis, list): 1804 if len(include_hysteresis) != len(include_classes): 1805 raise ValueError( 1806 "The suggestion_threshold_diff parameter has to be either a float value or a list of " 1807 f"float values of the same length as suggestion_classes (got a list of length " 1808 f"{len(include_hysteresis)} for {len(include_classes)} classes)" 1809 ) 1810 else: 1811 include_hysteresis = [include_hysteresis for _ in include_classes] 1812 if (suggestion_episodes is None and suggestions_prediction is None) and len( 1813 exclude_classes 1814 ) > 0: 1815 raise ValueError( 1816 "In order to exclude classes from the active learning intervals you need to set the " 1817 "suggestion_episode parameter" 1818 ) 1819 1820 task = None 1821 if error_episode is not None: 1822 task, parameters, mode = self._make_task_prediction( 1823 prediction_name=suggestions_name, 1824 load_episode=error_episode, 1825 parameters_update=parameters_update, 1826 load_epoch=error_load_epoch, 1827 data_path=data_path, 1828 mode=mode, 1829 
file_paths=file_paths, 1830 task=task, 1831 ) 1832 predicted_error = task.predict( 1833 data=mode, 1834 raw_output=True, 1835 apply_primary_function=True, 1836 augment_n=augment_n, 1837 ) 1838 else: 1839 predicted_error = None 1840 1841 if suggestion_episodes is not None: 1842 ( 1843 task, 1844 parameters, 1845 mode, 1846 predicted_classes, 1847 inference_time, 1848 behavior_dict, 1849 ) = self._make_prediction( 1850 prediction_name=suggestions_name, 1851 episode_names=suggestion_episodes, 1852 load_epochs=suggestion_load_epoch, 1853 parameters_update=parameters_update, 1854 data_path=data_path, 1855 file_paths=file_paths, 1856 mode=mode, 1857 task=task, 1858 ) 1859 elif suggestions_prediction is not None: 1860 with open( 1861 os.path.join( 1862 self.project_path, 1863 "results", 1864 "predictions", 1865 f"{suggestions_prediction}.pickle", 1866 ), 1867 "rb", 1868 ) as f: 1869 predicted_classes = pickle.load(f) 1870 if parameters_update is None: 1871 parameters_update = {} 1872 parameters_update = self._update( 1873 self._predictions().load_parameters(suggestions_prediction), 1874 parameters_update, 1875 ) 1876 parameters_update.pop("model") 1877 if suggestion_episodes is None: 1878 suggestion_episodes = [ 1879 os.path.basename( 1880 os.path.dirname( 1881 parameters_update["training"]["checkpoint_path"] 1882 ) 1883 ) 1884 ] 1885 task, parameters, mode = self._make_task_prediction( 1886 "_", 1887 load_episode=None, 1888 parameters_update=parameters_update, 1889 data_path=data_path, 1890 file_paths=file_paths, 1891 mode=mode, 1892 ) 1893 else: 1894 predicted_classes = None 1895 1896 if len(suggestion_classes) > 0 and predicted_classes is not None: 1897 suggestions = self._make_suggestions( 1898 task, 1899 predicted_error, 1900 predicted_classes, 1901 suggestion_threshold, 1902 suggestion_threshold_diff, 1903 suggestion_hysteresis, 1904 suggestion_episodes, 1905 suggestion_classes, 1906 error_threshold, 1907 min_frames_suggestion, 1908 min_frames_al, 1909 
visibility_min_score, 1910 visibility_min_frac, 1911 cut_annotated=cut_annotated, 1912 ) 1913 videos = list(suggestions.keys()) 1914 for v_id in videos: 1915 times_dict = defaultdict(lambda: defaultdict(lambda: [])) 1916 clips = set() 1917 for c in suggestions[v_id]: 1918 for start, end, ind in suggestions[v_id][c]: 1919 times_dict[ind][c].append([start, end, 2]) 1920 clips.add(ind) 1921 clips = list(clips) 1922 times_dict = dict(times_dict) 1923 times = [ 1924 [times_dict[ind][c] for c in suggestion_classes] for ind in clips 1925 ] 1926 save_path = self._suggestion_path(v_id, suggestions_name) 1927 Path(os.path.dirname(save_path)).mkdir(parents=True, exist_ok=True) 1928 with open(save_path, "wb") as f: 1929 pickle.dump((None, suggestion_classes, clips, times), f) 1930 1931 if ( 1932 error_episode is not None 1933 or len(exclude_classes) > 0 1934 or len(include_classes) > 0 1935 ): 1936 al_points = self._make_al_points( 1937 task, 1938 predicted_error, 1939 predicted_classes, 1940 exclude_classes, 1941 exclude_threshold, 1942 exclude_threshold_diff, 1943 exclude_hysteresis, 1944 include_classes, 1945 include_threshold, 1946 include_threshold_diff, 1947 include_hysteresis, 1948 error_episode, 1949 error_class, 1950 suggestion_episodes, 1951 error_threshold, 1952 error_threshold_diff, 1953 error_hysteresis, 1954 min_frames_al, 1955 visibility_min_score, 1956 visibility_min_frac, 1957 ) 1958 else: 1959 al_points = self._make_al_points_from_suggestions( 1960 suggestions_name, 1961 task, 1962 predicted_classes, 1963 background_threshold, 1964 visibility_min_score, 1965 visibility_min_frac, 1966 num_behaviors=len(task.behaviors_dict()), 1967 ) 1968 save_path = self._al_points_path(suggestions_name) 1969 Path(os.path.dirname(save_path)).mkdir(parents=True, exist_ok=True) 1970 with open(save_path, "wb") as f: 1971 pickle.dump(al_points, f) 1972 1973 meta_parameters = { 1974 "error_episode": error_episode, 1975 "error_load_epoch": error_load_epoch, 1976 "error_class": 
error_class, 1977 "suggestion_episode": suggestion_episodes, 1978 "suggestion_load_epoch": suggestion_load_epoch, 1979 "suggestion_classes": suggestion_classes, 1980 "error_threshold": error_threshold, 1981 "error_threshold_diff": error_threshold_diff, 1982 "error_hysteresis": error_hysteresis, 1983 "suggestion_threshold": suggestion_threshold, 1984 "suggestion_threshold_diff": suggestion_threshold_diff, 1985 "suggestion_hysteresis": suggestion_hysteresis, 1986 "min_frames_suggestion": min_frames_suggestion, 1987 "min_frames_al": min_frames_al, 1988 "visibility_min_score": visibility_min_score, 1989 "visibility_min_frac": visibility_min_frac, 1990 "augment_n": augment_n, 1991 "exclude_classes": exclude_classes, 1992 "exclude_threshold": exclude_threshold, 1993 "exclude_threshold_diff": exclude_threshold_diff, 1994 "exclude_hysteresis": exclude_hysteresis, 1995 } 1996 self._save_suggestions(suggestions_name, {}, meta_parameters) 1997 if data_path is not None or file_paths is not None or remove_saved_features: 1998 self._remove_stores(parameters) 1999 print(f"\n") 2000 2001 def _generate_similarity_score( 2002 self, 2003 prediction_name: str, 2004 target_video_id: str, 2005 target_clip: str, 2006 target_start: int, 2007 target_end: int, 2008 ) -> Dict: 2009 with open( 2010 os.path.join( 2011 self.project_path, 2012 "results", 2013 "predictions", 2014 f"{prediction_name}.pickle", 2015 ), 2016 "rb", 2017 ) as f: 2018 prediction = pickle.load(f) 2019 target = prediction[target_video_id][target_clip][:, target_start:target_end] 2020 score_dict = defaultdict(lambda: {}) 2021 for video_id in prediction: 2022 for clip_id in prediction[video_id]: 2023 score_dict[video_id][clip_id] = torch.cdist( 2024 target.T, prediction[video_id][score_dict].T 2025 ).min(0) 2026 return score_dict 2027 2028 def _suggest_intervals_from_dict(self, score_dict, min_length, n_intervals) -> Dict: 2029 """Suggest intervals from a score dictionary. 
2030 2031 Parameters 2032 ---------- 2033 score_dict : dict 2034 a dictionary containing scores for intervals 2035 min_length : int 2036 minimum length of intervals to suggest 2037 n_intervals : int 2038 number of intervals to suggest 2039 2040 Returns 2041 ------- 2042 intervals : dict 2043 a dictionary of suggested intervals 2044 2045 """ 2046 interval_address = {} 2047 interval_value = {} 2048 s = 0 2049 n = 0 2050 for video_id, video_dict in score_dict.items(): 2051 for clip_id, value in video_dict.items(): 2052 s += value.mean() 2053 n += 1 2054 mean_value = s / n 2055 alpha = 1.75 2056 for it in range(10): 2057 id = 0 2058 interval_address = {} 2059 interval_value = {} 2060 for video_id, video_dict in score_dict.items(): 2061 for clip_id, value in video_dict.items(): 2062 res_indices_start, res_indices_end = apply_threshold( 2063 value, 2064 threshold=(2 - alpha * (0.9**it)) * mean_value, 2065 low=True, 2066 error_mask=None, 2067 min_frames=min_length, 2068 smooth_interval=0, 2069 ) 2070 for start, end in zip(res_indices_start, res_indices_end): 2071 interval_address[id] = [video_id, clip_id, start, end] 2072 interval_value[id] = score_dict[video_id][clip_id][ 2073 start:end 2074 ].mean() 2075 id += 1 2076 if len(interval_address) >= n_intervals: 2077 break 2078 if len(interval_address) < n_intervals: 2079 warnings.warn( 2080 f"Could not get {n_intervals} intervals from the data, saving the result with {len(interval_address)} intervals" 2081 ) 2082 sorted_intervals = sorted( 2083 interval_value.items(), key=lambda x: x[1], reverse=True 2084 ) 2085 output_intervals = [ 2086 interval_address[x[0]] 2087 for x in sorted_intervals[: min(len(sorted_intervals), n_intervals)] 2088 ] 2089 output = defaultdict(lambda: []) 2090 for video_id, clip_id, start, end in output_intervals: 2091 output[video_id].append([start, end, clip_id]) 2092 return output 2093 2094 def suggest_intervals_with_similarity( 2095 self, 2096 suggestions_name: str, 2097 prediction_name: str, 2098 
target_video_id: str, 2099 target_clip: str, 2100 target_start: int, 2101 target_end: int, 2102 min_length: int = 60, 2103 n_intervals: int = 5, 2104 force: bool = False, 2105 ): 2106 """ 2107 Suggest intervals based on similarity to a target interval. 2108 2109 Parameters 2110 ---------- 2111 suggestions_name : str 2112 Name of the suggestion. 2113 prediction_name : str 2114 Name of the prediction to use. 2115 target_video_id : str 2116 Video id of the target interval. 2117 target_clip : str 2118 Clip id of the target interval. 2119 target_start : int 2120 Start frame of the target interval. 2121 target_end : int 2122 End frame of the target interval. 2123 min_length : int, default 60 2124 Minimum length of the suggested intervals. 2125 n_intervals : int, default 5 2126 Number of suggested intervals. 2127 force : bool, default False 2128 If True, the suggestion is overwritten if it already exists. 2129 2130 """ 2131 self._check_suggestions_validity(suggestions_name, force=force) 2132 print(f"SUGGESTION {suggestions_name}") 2133 score_dict = self._generate_similarity_score( 2134 prediction_name, target_video_id, target_clip, target_start, target_end 2135 ) 2136 intervals = self._suggest_intervals_from_dict( 2137 score_dict, min_length, n_intervals 2138 ) 2139 suggestions_path = os.path.join( 2140 self.project_path, 2141 "results", 2142 "suggestions", 2143 suggestions_name, 2144 ) 2145 if not os.path.exists(suggestions_path): 2146 os.mkdir(suggestions_path) 2147 with open( 2148 os.path.join(suggestions_path, f"{suggestions_name}_al_points.pickle"), "wb" 2149 ) as f: 2150 pickle.dump(intervals, f) 2151 meta_parameters = { 2152 "prediction_name": prediction_name, 2153 "min_frames_suggestion": min_length, 2154 "n_intervals": n_intervals, 2155 "target_clip": target_clip, 2156 "target_start": target_start, 2157 "target_end": target_end, 2158 } 2159 self._save_suggestions(suggestions_name, {}, meta_parameters) 2160 print("\n") 2161 2162 def 
suggest_intervals_with_uncertainty( 2163 self, 2164 suggestions_name: str, 2165 episode_names: List, 2166 load_epochs: List = None, 2167 classes: List = None, 2168 n_frames: int = 10000, 2169 method: str = "least_confidence", 2170 min_length: int = 60, 2171 augment_n: int = 0, 2172 data_path: str = None, 2173 file_paths: Set = None, 2174 parameters_update: Dict = None, 2175 mode: str = "all", 2176 force: bool = False, 2177 remove_saved_features: bool = False, 2178 ) -> None: 2179 """Generate an active learning file based on model uncertainty. 2180 2181 If you provide several episode names, the predicted probabilities will be averaged. 2182 2183 Parameters 2184 ---------- 2185 suggestions_name : str 2186 the name of the suggestion 2187 episode_names : list 2188 a list of string episode names to load the models from 2189 load_epochs : list, optional 2190 a list of epoch indices to load the models from (if `None`, the last ones will be used) 2191 classes : list, optional 2192 a list of classes to look at (by default all) 2193 n_frames : int, default 10000 2194 the threshold total number of frames in the suggested intervals (in the end result it will most likely 2195 be slightly larger; it will only be smaller if the algorithm fails to find enough intervals 2196 with the set parameters) 2197 method : {"least_confidence", "entropy"} 2198 the method used to calculate the scores from the probability predictions (`"least_confidence"`: `1 - p_i` if 2199 `p_i > 0.5` or `p_i` if `p_i < 0.5`; `"entropy"`: `- p_i * log(p_i) - (1 - p_i) * log(1 - p_i)`) 2200 min_length : int, default 60 2201 the minimum number of frames in one interval 2202 augment_n : int, default 0 2203 the number of augmentations to average the predictions over 2204 data_path : str, optional 2205 the path to a data folder (by default, the project data is used) 2206 file_paths : set, optional 2207 a list of file paths (by default, the project data is used) 2208 parameters_update : dict, optional 2209 a 
dictionary of parameter updates 2210 mode : {"test", "val", "train", "all"} 2211 the subset of the data to make the prediction for (forced to 'all' if `data_path` is not `None`; 2212 by default set to `'test'` if the test subset if not empty, or to `'val'` otherwise) 2213 force : bool, default False 2214 if `True`, existing suggestions with the same name will be overwritten 2215 remove_saved_features : bool, default False 2216 if `True`, the dataset will be deleted after the computation 2217 2218 """ 2219 self._check_suggestions_validity(suggestions_name, force=force) 2220 print(f"SUGGESTION {suggestions_name}") 2221 task, parameters, mode, predicted, inference_time, behavior_dict = ( 2222 self._make_prediction( 2223 suggestions_name, 2224 episode_names, 2225 load_epochs, 2226 parameters_update, 2227 data_path=data_path, 2228 file_paths=file_paths, 2229 mode=mode, 2230 augment_n=augment_n, 2231 evaluate=False, 2232 ) 2233 ) 2234 if classes is None: 2235 classes = behavior_dict.values() 2236 episode = self._episodes().get_runs(episode_names[0])[0] 2237 score_tensors = task.generate_uncertainty_score( 2238 classes, 2239 augment_n, 2240 method, 2241 predicted, 2242 self._episode(episode).get_behaviors_dict(), 2243 ) 2244 intervals = self._suggest_intervals( 2245 task.dataset(mode), score_tensors, n_frames, min_length 2246 ) 2247 for k, v in intervals.items(): 2248 l = sum([x[1] - x[0] for x in v]) 2249 print(f"{k}: {len(v)} ({l})") 2250 if remove_saved_features: 2251 self._remove_stores(parameters) 2252 suggestions_path = os.path.join( 2253 self.project_path, 2254 "results", 2255 "suggestions", 2256 suggestions_name, 2257 ) 2258 if not os.path.exists(suggestions_path): 2259 os.mkdir(suggestions_path) 2260 with open( 2261 os.path.join(suggestions_path, f"{suggestions_name}_al_points.pickle"), "wb" 2262 ) as f: 2263 pickle.dump(intervals, f) 2264 meta_parameters = { 2265 "suggestion_episode": episode_names, 2266 "suggestion_load_epoch": load_epochs, 2267 
"suggestion_classes": classes, 2268 "min_frames_suggestion": min_length, 2269 "augment_n": augment_n, 2270 "method": method, 2271 "num_frames": n_frames, 2272 } 2273 self._save_suggestions(suggestions_name, {}, meta_parameters) 2274 print("\n") 2275 2276 def suggest_intervals_with_bald( 2277 self, 2278 suggestions_name: str, 2279 episode_name: str, 2280 load_epoch: int = None, 2281 classes: List = None, 2282 n_frames: int = 10000, 2283 num_models: int = 10, 2284 kernel_size: int = 11, 2285 min_length: int = 60, 2286 augment_n: int = 0, 2287 data_path: str = None, 2288 file_paths: Set = None, 2289 parameters_update: Dict = None, 2290 mode: str = "all", 2291 force: bool = False, 2292 remove_saved_features: bool = False, 2293 ): 2294 """Generate an active learning file based on Bayesian Active Learning by Disagreement. 2295 2296 Parameters 2297 ---------- 2298 suggestions_name : str 2299 the name of the suggestion 2300 episode_name : str 2301 the name of the episode to load the model from 2302 load_epoch : int, optional 2303 the index of the epoch to load the model from (if `None`, the last one will be used) 2304 classes : list, optional 2305 a list of classes to look at (by default all) 2306 n_frames : int, default 10000 2307 the threshold total number of frames in the suggested intervals (in the end result it will most likely 2308 be slightly larger; it will only be smaller if the algorithm fails to find enough intervals 2309 with the set parameters) 2310 num_models : int, default 10 2311 the number of dropout masks to apply 2312 kernel_size : int, default 11 2313 the size of the smoothing kernel applied to the discrete results 2314 min_length : int, default 60 2315 the minimum number of frames in one interval 2316 augment_n : int, default 0 2317 the number of augmentations to average the predictions over 2318 data_path : str, optional 2319 the path to a data folder (by default, the project data is used) 2320 file_paths : set, optional 2321 a list of file paths (by 
default, the project data is used) 2322 parameters_update : dict, optional 2323 a dictionary of parameter updates 2324 mode : {"test", "val", "train", "all"} 2325 the subset of the data to make the prediction for (forced to 'all' if `data_path` is not `None`; 2326 by default set to `'test'` if the test subset if not empty, or to `'val'` otherwise) 2327 force : bool, default False 2328 if `True`, existing suggestions with the same name will be overwritten 2329 remove_saved_features : bool, default False 2330 if `True`, the dataset will be deleted after the computation 2331 2332 """ 2333 self._check_suggestions_validity(suggestions_name, force=force) 2334 print(f"SUGGESTION {suggestions_name}") 2335 task, parameters, mode = self._make_task_prediction( 2336 suggestions_name, 2337 episode_name, 2338 parameters_update, 2339 load_epoch, 2340 data_path=data_path, 2341 file_paths=file_paths, 2342 mode=mode, 2343 ) 2344 if classes is None: 2345 classes = list(task.behaviors_dict().values()) 2346 score_tensors = task.generate_bald_score( 2347 classes, augment_n, num_models, kernel_size 2348 ) 2349 intervals = self._suggest_intervals( 2350 task.dataset(mode), score_tensors, n_frames, min_length 2351 ) 2352 if remove_saved_features: 2353 self._remove_stores(parameters) 2354 suggestions_path = os.path.join( 2355 self.project_path, 2356 "results", 2357 "suggestions", 2358 suggestions_name, 2359 ) 2360 if not os.path.exists(suggestions_path): 2361 os.mkdir(suggestions_path) 2362 with open( 2363 os.path.join(suggestions_path, f"{suggestions_name}_al_points.pickle"), "wb" 2364 ) as f: 2365 pickle.dump(intervals, f) 2366 meta_parameters = { 2367 "suggestion_episode": episode_name, 2368 "suggestion_load_epoch": load_epoch, 2369 "suggestion_classes": classes, 2370 "min_frames_suggestion": min_length, 2371 "augment_n": augment_n, 2372 "method": f"BALD:{num_models}", 2373 "num_frames": n_frames, 2374 } 2375 self._save_suggestions(suggestions_name, {}, meta_parameters) 2376 print("\n") 
    def list_episodes(
        self,
        episode_names: List = None,
        value_filter: str = "",
        display_parameters: List = None,
        print_results: bool = True,
    ) -> pd.DataFrame:
        """Get a filtered pandas dataframe with episode metadata.

        Parameters
        ----------
        episode_names : list
            a list of strings of episode names
        value_filter : str
            a string of filters to apply; of this general structure:
            'group_name1/par_name1::(</>/<=/>=/=)value1,group_name2/par_name2::(</>/<=/>=/=)value2', e.g.
            'data/overlap::=50,results/recall::>0.5,data/feature_extraction::=kinematic,meta/training_time::>=00:00:10'
        display_parameters : list
            list of parameters to display (e.g. ['data/overlap', 'results/recall'])
        print_results : bool, default True
            if True, the result will be printed to standard output

        Returns
        -------
        pd.DataFrame
            the filtered dataframe

        """
        # the filtering itself is delegated to the episodes metadata store
        episodes = self._episodes().list_episodes(
            episode_names, value_filter, display_parameters
        )
        if print_results:
            print("TRAINING EPISODES")
            print(episodes)
            print("\n")
        return episodes

    def list_predictions(
        self,
        episode_names: List = None,
        value_filter: str = "",
        display_parameters: List = None,
        print_results: bool = True,
    ) -> pd.DataFrame:
        """Get a filtered pandas dataframe with prediction metadata.

        Parameters
        ----------
        episode_names : list
            a list of strings of episode names
        value_filter : str
            a string of filters to apply; of this general structure:
            'group_name1/par_name1:(<>=)value1,group_name2/par_name2:(<>=)value2', e.g.
            'data/overlap:=50,results/recall:>0.5,data/feature_extraction:=kinematic'
        display_parameters : list
            list of parameters to display (e.g. ['data/overlap', 'results/recall'])
        print_results : bool, default True
            if True, the result will be printed to standard output

        Returns
        -------
        pd.DataFrame
            the filtered dataframe

        """
        # NOTE(review): the filter syntax documented here uses ':' while `list_episodes` documents '::';
        # both delegate to a store method named `list_episodes` -- confirm which separator is accepted
        predictions = self._predictions().list_episodes(
            episode_names, value_filter, display_parameters
        )
        if print_results:
            print("PREDICTIONS")
            print(predictions)
            print("\n")
        return predictions

    def list_suggestions(
        self,
        suggestions_names: List = None,
        value_filter: str = "",
        display_parameters: List = None,
        print_results: bool = True,
    ) -> pd.DataFrame:
        """Get a filtered pandas dataframe with suggestion metadata.

        Parameters
        ----------
        suggestions_names : list
            a list of strings of suggestion names
        value_filter : str
            a string of filters to apply; of this general structure:
            'group_name1/par_name1:(<>=)value1,group_name2/par_name2:(<>=)value2', e.g.
            'data/overlap:=50,results/recall:>0.5,data/feature_extraction:=kinematic'
        display_parameters : list
            list of parameters to display (e.g. ['data/overlap', 'results/recall'])
        print_results : bool, default True
            if True, the result will be printed to standard output

        Returns
        -------
        pd.DataFrame
            the filtered dataframe

        """
        # the filtering itself is delegated to the suggestions metadata store
        suggestions = self._suggestions().list_episodes(
            suggestions_names, value_filter, display_parameters
        )
        if print_results:
            print("SUGGESTIONS")
            print(suggestions)
            print("\n")
        return suggestions

    def list_searches(
        self,
        search_names: List = None,
        value_filter: str = "",
        display_parameters: List = None,
        print_results: bool = True,
    ) -> pd.DataFrame:
        """Get a filtered pandas dataframe with hyperparameter search metadata.

        Parameters
        ----------
        search_names : list
            a list of strings of search names
        value_filter : str
            a string of filters to apply; of this general structure:
            'group_name1/par_name1:(<>=)value1,group_name2/par_name2:(<>=)value2', e.g.
            'data/overlap:=50,results/recall:>0.5,data/feature_extraction:=kinematic'
        display_parameters : list
            list of parameters to display (e.g. ['data/overlap', 'results/recall'])
        print_results : bool, default True
            if True, the result will be printed to standard output

        Returns
        -------
        pd.DataFrame
            the filtered dataframe

        """
        # the filtering itself is delegated to the searches metadata store
        searches = self._searches().list_episodes(
            search_names, value_filter, display_parameters
        )
        if print_results:
            print("SEARCHES")
            print(searches)
            print("\n")
        return searches

    def get_best_parameters(
        self,
        search_name: str,
        round_to_binary: List = None,
    ):
        """Get the best parameters found by a search.

        The name of the best model returned by the search is merged into the result
        under `general/model_name`.

        Parameters
        ----------
        search_name : str
            the name of the search
        round_to_binary : list, default None
            a list of parameters to round to binary values

        Returns
        -------
        best_params : dict
            a dictionary of the best parameters, updated with the model name

        """
        params, model = self._searches().get_best_params(
            search_name, round_to_binary=round_to_binary
        )
        # NOTE(review): the update below is a nested {group: {name: value}} dictionary, so the result is
        # presumably nested as well rather than keyed by '{group}/{name}' -- confirm against `_update`
        params = self._update(params, {"general": {"model_name": model}})
        return params

    def list_best_parameters(
        self, search_name: str, print_results: bool = True
    ) -> Dict:
        """Get the raw dictionary of best parameters found by a search.

        Parameters
        ----------
        search_name : str
            the name of the search
        print_results : bool, default True
            if True, the result will be printed to standard output

        Returns
        -------
        best_params : dict
            a dictionary of the best parameters where the keys are in '{group}/{name}' format

        """
        params = self._searches().get_best_params_raw(search_name)
        if print_results:
            print(f"SEARCH RESULTS {search_name}")
            # one "key: value" line per parameter
            for k, v in params.items():
                print(f"{k}: {v}")
            print("\n")
        return params

    def plot_episodes(
        self,
        episode_names: List,
        metrics: List | str,
        modes: List | str = None,
        title: str = None,
        episode_labels: List = None,
        save_path: str = None,
        add_hlines: List = None,
        epoch_limits: List = None,
        colors: List = None,
        add_highpoint_hlines: bool = False,
        remove_box: bool = False,
        font_size: float = None,
        linewidth: float = None,
        return_ax: bool = False,
    ) -> None:
        """Plot episode training curves.

        Parameters
        ----------
        episode_names : list
            a list of episode names to plot; to plot to episodes in one line combine them in a list
            (e.g.
['episode1', ['episode2', 'episode3']] to plot episode2 and episode3 as one experiment) 2602 metrics : list 2603 a list of metric to plot 2604 modes : list, optional 2605 a list of modes to plot ('train' and/or 'val'; `['val']` by default) 2606 title : str, optional 2607 title for the plot 2608 episode_labels : list, optional 2609 a list of strings used to label the curves (has to be the same length as episode_names) 2610 save_path : str, optional 2611 the path to save the resulting plot 2612 add_hlines : list, optional 2613 a list of float values (or (value, label) tuples) to mark with horizontal lines 2614 epoch_limits : list, optional 2615 a list of (min, max) tuples to set the x-axis limits for each episode 2616 colors: list, optional 2617 a list of matplotlib colors 2618 add_highpoint_hlines : bool, default False 2619 if `True`, horizontal lines will be added at the highest value of each episode 2620 """ 2621 2622 if isinstance(metrics, str): 2623 metrics = [metrics] 2624 if isinstance(modes, str): 2625 modes = [modes] 2626 2627 if font_size is not None: 2628 font = {"size": font_size} 2629 rc("font", **font) 2630 if modes is None: 2631 modes = ["val"] 2632 if add_hlines is None: 2633 add_hlines = [] 2634 logs = [] 2635 epochs = [] 2636 labels = [] 2637 if episode_labels is not None: 2638 assert len(episode_labels) == len(episode_names) 2639 for name_i, name in enumerate(episode_names): 2640 log_params = product(metrics, modes) 2641 for metric, mode in log_params: 2642 if episode_labels is not None: 2643 label = episode_labels[name_i] 2644 else: 2645 label = deepcopy(name) 2646 if len(modes) != 1: 2647 label += f"_{mode}" 2648 if len(metrics) != 1: 2649 label += f"_{metric}" 2650 labels.append(label) 2651 if isinstance(name, Iterable) and not isinstance(name, str): 2652 epoch_list = defaultdict(lambda: []) 2653 multi_logs = defaultdict(lambda: []) 2654 for i, n in enumerate(name): 2655 runs = self._episodes().get_runs(n) 2656 if len(runs) > 1: 2657 for run in 
runs: 2658 if "::" in run: 2659 index = run.split("::")[-1] 2660 else: 2661 index = run.split("#")[-1] 2662 if multi_logs[index] == []: 2663 if multi_logs["null"] is None: 2664 raise RuntimeError( 2665 "The run indices are not consistent across episodes!" 2666 ) 2667 else: 2668 multi_logs[index] += multi_logs["null"] 2669 multi_logs[index] += list( 2670 self._episode(run).get_metric_log(mode, metric) 2671 ) 2672 start = ( 2673 0 2674 if len(epoch_list[index]) == 0 2675 else epoch_list[index][-1] 2676 ) 2677 epoch_list[index] += [ 2678 x + start 2679 for x in self._episode(run).get_epoch_list(mode) 2680 ] 2681 multi_logs["null"] = None 2682 else: 2683 if len(multi_logs.keys()) > 1: 2684 raise RuntimeError( 2685 "Cannot plot a single-run episode after a multi-run episode!" 2686 ) 2687 multi_logs["null"] += list( 2688 self._episode(n).get_metric_log(mode, metric) 2689 ) 2690 start = ( 2691 0 2692 if len(epoch_list["null"]) == 0 2693 else epoch_list["null"][-1] 2694 ) 2695 epoch_list["null"] += [ 2696 x + start for x in self._episode(n).get_epoch_list(mode) 2697 ] 2698 if len(multi_logs.keys()) == 1: 2699 log = multi_logs["null"] 2700 epochs.append(epoch_list["null"]) 2701 else: 2702 log = tuple([v for k, v in multi_logs.items() if k != "null"]) 2703 epochs.append( 2704 tuple([v for k, v in epoch_list.items() if k != "null"]) 2705 ) 2706 else: 2707 runs = self._episodes().get_runs(name) 2708 if len(runs) > 1: 2709 log = [] 2710 for run in runs: 2711 tracked_metrics = self._episode(run).get_metrics() 2712 if metric in tracked_metrics: 2713 log.append( 2714 list( 2715 self._episode(run).get_metric_log(mode, metric) 2716 ) 2717 ) 2718 else: 2719 relevant = [] 2720 for m in tracked_metrics: 2721 m_split = m.split("_") 2722 if ( 2723 "_".join(m_split[:-1]) == metric 2724 and m_split[-1].isnumeric() 2725 ): 2726 relevant.append(m) 2727 if len(relevant) == 0: 2728 raise ValueError( 2729 f"The {metric} metric was not tracked at {run}" 2730 ) 2731 arr = 0 2732 for m in 
relevant: 2733 arr += self._episode(run).get_metric_log(mode, m) 2734 arr /= len(relevant) 2735 log.append(list(arr)) 2736 log = tuple(log) 2737 epochs.append( 2738 tuple( 2739 [ 2740 self._episode(run).get_epoch_list(mode) 2741 for run in runs 2742 ] 2743 ) 2744 ) 2745 else: 2746 tracked_metrics = self._episode(name).get_metrics() 2747 if metric in tracked_metrics: 2748 log = list(self._episode(name).get_metric_log(mode, metric)) 2749 else: 2750 relevant = [] 2751 for m in tracked_metrics: 2752 m_split = m.split("_") 2753 if ( 2754 "_".join(m_split[:-1]) == metric 2755 and m_split[-1].isnumeric() 2756 ): 2757 relevant.append(m) 2758 if len(relevant) == 0: 2759 raise ValueError( 2760 f"The {metric} metric was not tracked at {name}" 2761 ) 2762 arr = 0 2763 for m in relevant: 2764 arr += self._episode(name).get_metric_log(mode, m) 2765 arr /= len(relevant) 2766 log = list(arr) 2767 epochs.append(self._episode(name).get_epoch_list(mode)) 2768 logs.append(log) 2769 # if episode_labels is not None: 2770 # print(f'{len(episode_labels)=}, {len(logs)=}') 2771 # if len(episode_labels) != len(logs): 2772 2773 # raise ValueError( 2774 # f"The length of episode_labels ({len(episode_labels)}) has to be equal to the length of " 2775 # f"curves ({len(logs)})!" 2776 # ) 2777 # else: 2778 # labels = episode_labels 2779 if colors is None: 2780 colors = cm.rainbow(np.linspace(0, 1, len(logs))) 2781 if len(colors) != len(logs): 2782 raise ValueError( 2783 "The length of colors has to be equal to the length of curves (metrics * modes * episode_names)!" 
2784 ) 2785 f, ax = plt.subplots() 2786 length = 0 2787 for log, label, color, epoch_list in zip(logs, labels, colors, epochs): 2788 if type(log) is list: 2789 if len(log) > length: 2790 length = len(log) 2791 ax.plot( 2792 epoch_list, 2793 log, 2794 label=label, 2795 color=color, 2796 ) 2797 if add_highpoint_hlines: 2798 ax.axhline(np.max(log), linestyle="dashed", color=color) 2799 else: 2800 for l, xx in zip(log, epoch_list): 2801 if len(l) > length: 2802 length = len(l) 2803 ax.plot( 2804 xx, 2805 l, 2806 color=color, 2807 alpha=0.2, 2808 ) 2809 if not all([len(x) == len(log[0]) for x in log]): 2810 warnings.warn( 2811 f"Got logs with unequal lengths in parallel runs for {label}" 2812 ) 2813 log = list(log) 2814 epoch_list = list(epoch_list) 2815 for i, x in enumerate(epoch_list): 2816 to_remove = [] 2817 for j, y in enumerate(x[1:]): 2818 if y <= x[j - 1]: 2819 y_ind = x.index(y) 2820 to_remove += list(range(y_ind, j)) 2821 epoch_list[i] = [ 2822 y for j, y in enumerate(x) if j not in to_remove 2823 ] 2824 log[i] = [y for j, y in enumerate(log[i]) if j not in to_remove] 2825 length = min([len(x) for x in log]) 2826 for i in range(len(log)): 2827 log[i] = log[i][:length] 2828 epoch_list[i] = epoch_list[i][:length] 2829 if not all([x == epoch_list[0] for x in epoch_list]): 2830 raise RuntimeError( 2831 f"Got different epoch indices in parallel runs for {label}" 2832 ) 2833 mean = np.array(log).mean(0) 2834 ax.plot( 2835 epoch_list[0], 2836 mean, 2837 label=label, 2838 color=color, 2839 linewidth=linewidth, 2840 ) 2841 if add_highpoint_hlines: 2842 ax.axhline(np.max(mean), linestyle="dashed", color=color) 2843 for x in add_hlines: 2844 label = None 2845 if isinstance(x, Iterable): 2846 x, label = x 2847 ax.axhline(x, label=label) 2848 ax.set_xlim((0, length)) 2849 2850 ax.legend() 2851 ax.set_xlabel("epochs") 2852 if len(metrics) == 1: 2853 ax.set_ylabel(metrics[0]) 2854 else: 2855 ax.set_ylabel("value") 2856 if title is None: 2857 if len(episode_names) == 1: 2858 
def update_parameters(
    self,
    parameters_update: Dict = None,
    load_search: str = None,
    load_parameters: List = None,
    round_to_binary: List = None,
) -> None:
    """Update the parameters in the project config files.

    The current configuration is read, patched with `parameters_update` and / or the best parameters
    of a hyperparameter search and written back to the config folder of the project.

    Parameters
    ----------
    parameters_update : dict, optional
        a dictionary of parameter updates; the "model", "features" and "augmentations" entries are applied
        to the config files selected by the (updated) general parameters; the dictionary itself is
        left unmodified
    load_search : str, optional
        the name of hyperparameter search results to load to config
    load_parameters : list, optional
        a list of lists of string names of the parameters to load from the searches
    round_to_binary : list, optional
        a list of string names of the loaded parameters that should be rounded to the nearest power of two

    """
    keys = [
        "general",
        "losses",
        "metrics",
        "ssl",
        "training",
        "data",
    ]
    # Sections that live in files whose name depends on the general parameters.
    named_sections = ("model", "features", "augmentations")
    config_path = os.path.join(self.project_path, "config")
    parameters = self._read_parameters(catch_blanks=False)
    if parameters_update is not None:
        model_params = parameters_update.get("model")
        feat_params = parameters_update.get("features")
        aug_params = parameters_update.get("augmentations")
        # Filter into a new dictionary instead of popping the keys, so that the
        # dictionary owned by the caller keeps all of its entries.
        general_update = {
            k: v for k, v in parameters_update.items() if k not in named_sections
        }

        parameters = self._update(parameters, general_update)
        # The named sections are re-read after the general update because the
        # update may have switched the model or the feature extractor.
        model_name = parameters["general"]["model_name"]
        parameters["model"] = self._open_yaml(
            os.path.join(config_path, "model", f"{model_name}.yaml")
        )
        if model_params is not None:
            parameters["model"] = self._update(parameters["model"], model_params)
        feat_name = parameters["general"]["feature_extraction"]
        parameters["features"] = self._open_yaml(
            os.path.join(config_path, "features", f"{feat_name}.yaml")
        )
        if feat_params is not None:
            parameters["features"] = self._update(
                parameters["features"], feat_params
            )
        aug_name = options.extractor_to_transformer[feat_name]
        parameters["augmentations"] = self._open_yaml(
            os.path.join(config_path, "augmentations", f"{aug_name}.yaml")
        )
        if aug_params is not None:
            parameters["augmentations"] = self._update(
                parameters["augmentations"], aug_params
            )
    if load_search is not None:
        search_update, model_name = self._searches().get_best_params(
            load_search, load_parameters, round_to_binary
        )
        parameters["general"]["model_name"] = model_name
        parameters["model"] = self._open_yaml(
            os.path.join(config_path, "model", f"{model_name}.yaml")
        )
        parameters = self._update(parameters, search_update)
    for key in keys:
        with open(
            os.path.join(config_path, f"{key}.yaml"), "w", encoding="utf-8"
        ) as f:
            YAML().dump(parameters[key], f)
    model_name = parameters["general"]["model_name"]
    model_path = os.path.join(config_path, "model", f"{model_name}.yaml")
    with open(model_path, "w", encoding="utf-8") as f:
        YAML().dump(parameters["model"], f)
    features_name = parameters["general"]["feature_extraction"]
    features_path = os.path.join(config_path, "features", f"{features_name}.yaml")
    with open(features_path, "w", encoding="utf-8") as f:
        YAML().dump(parameters["features"], f)
    aug_name = options.extractor_to_transformer[features_name]
    aug_path = os.path.join(config_path, "augmentations", f"{aug_name}.yaml")
    with open(aug_path, "w", encoding="utf-8") as f:
        YAML().dump(parameters["augmentations"], f)
@staticmethod
def remove_project(name: str, projects_path: str = None) -> None:
    """Delete the folder of a project together with everything stored in it.

    Nothing happens if the project folder does not exist.

    Parameters
    ----------
    name : str
        the name of the project to remove
    projects_path : str, optional
        the path to the projects directory (by default the home DLC2Action directory)

    """
    root = projects_path
    if root is None:
        root = os.path.join(str(Path.home()), "DLC2Action")
    target = os.path.join(root, name)
    if not os.path.exists(target):
        return
    shutil.rmtree(target)
def remove_extra_checkpoints(
    self, episode_names: List = None, exceptions: List = None
) -> None:
    """Remove intermediate model checkpoint files (only leave the files for the last epoch).

    By default, all intermediate checkpoints will be deleted.
    Files in the model folder that are not associated with any record in the meta files are also deleted.

    Parameters
    ----------
    episode_names : list, optional
        a list of episode names to clean (by default all names are added)
    exceptions : list, optional
        a list of episode names to not clean

    """
    import re

    def natural_key(file_name: str) -> list:
        # re.split with a capturing group alternates text / digit runs, so odd
        # positions are numbers; comparing them as ints puts "epoch10" after "epoch9".
        parts = re.split(r"(\d+)", file_name)
        return [int(p) if i % 2 else p for i, p in enumerate(parts)]

    model_path = os.path.join(self.project_path, "results", "model")
    try:
        all_names = self._episodes().data.index
    except Exception:
        # The episode records could not be read: treat every existing model
        # folder as a known episode so that nothing is removed as an orphan.
        all_names = os.listdir(model_path)
    if episode_names is None:
        episode_names = all_names
    if exceptions is None:
        exceptions = []
    known = set(all_names)
    to_clean = {x for x in episode_names if x not in exceptions}
    for folder in os.listdir(model_path):
        folder_path = os.path.join(model_path, folder)
        if folder not in known:
            shutil.rmtree(folder_path)
        elif folder in to_clean:
            # Keep only the last file in natural order.
            for file in sorted(os.listdir(folder_path), key=natural_key)[:-1]:
                os.remove(os.path.join(folder_path, file))
@staticmethod
def _reformat_results(res: dict, classes: dict, exclusive=False):
    """Add class names to the per-class (list-valued) metrics in evaluation results.

    Parameters
    ----------
    res : dict
        a dictionary of metric values; list values are treated as one value per class
    classes : dict
        a dictionary with class indices as keys and class names as values
    exclusive : bool, default False
        if `True` and `classes` has exactly one entry more than a metric has values, the "other" class
        is dropped and the indices above it are shifted down by one

    Returns
    -------
    results : dict
        a deep copy of `res` where every list value is replaced by a dictionary that maps class names
        to float values

    """
    results = deepcopy(res)
    for key, value in results.items():
        if not isinstance(value, list):
            continue
        if exclusive and len(classes) == len(value) + 1:
            # The metric has no entry for "other": drop it and close the gap in the indices.
            other_ind = list(classes.keys())[list(classes.values()).index("other")]
            classes = {
                (i if i < other_ind else i - 1): c
                for i, c in classes.items()
                if i != other_ind
            }
        assert len(value) == len(
            classes
        ), f"Results for {key} have {len(value)} values, but {len(classes)} classes were provided!"
        # Only the value of an existing key is replaced, so iterating over the items stays valid.
        results[key] = {classes[i]: float(v) for i, v in enumerate(value)}
    return results
@classmethod
def print_annotation_types(cls):
    """Print the names and the docstrings of the available annotation types."""
    print("ANNOTATION TYPES:")
    available = cls.annotation_types()
    # Iterating over a mapping directly yields only its keys, which cannot be
    # unpacked into (name, class) pairs; go through the items like
    # print_data_types does. Sequences of pairs are still accepted as they are.
    pairs = available.items() if isinstance(available, Mapping) else available
    for key, value in pairs:
        print(f"{key}:")
        print(value.__doc__)
def _episodes(self) -> SavedRuns:
    """Get the episodes meta object.

    If the object cannot be created, `load_metadata_backup` is called and the creation is
    attempted once more; a failure of the second attempt is propagated.

    Returns
    -------
    episodes : SavedRuns
        the episodes meta object

    """
    try:
        return SavedRuns(self._episodes_path(), self.project_path)
    except Exception:
        # Deliberately not a bare `except`: KeyboardInterrupt and SystemExit have to
        # propagate instead of triggering a restore of the metadata backup.
        self.load_metadata_backup()
        return SavedRuns(self._episodes_path(), self.project_path)
def _episode(self, name: str) -> Run:
    """Get an episode meta object.

    If the object cannot be created, `load_metadata_backup` is called and the creation is
    attempted once more; a failure of the second attempt is propagated.

    Parameters
    ----------
    name : str
        episode name

    Returns
    -------
    episode : Run
        the episode meta object

    """
    try:
        return Run(name, self.project_path, meta_path=self._episodes_path())
    except Exception:
        # Deliberately not a bare `except`: KeyboardInterrupt and SystemExit have to
        # propagate instead of triggering a restore of the metadata backup.
        self.load_metadata_backup()
        return Run(name, self.project_path, meta_path=self._episodes_path())
def _update_project(self) -> None:
    """Update project files with the current version.

    The project counts as outdated when its version file is missing or holds a version string that
    compares lower than the installed `__version__`. For an outdated project, annotation configs
    that are missing from the project are copied from the package; for the ones that exist, keys
    missing from the project file are added to the loaded parameters (with their comments) and
    propagated to the episode records.

    """
    version_file = self._version_path()
    ok = True
    if not os.path.exists(version_file):
        ok = False
    else:
        with open(version_file, encoding="utf-8") as f:
            # Strip so that a trailing newline does not make equal versions differ.
            project_version = f.read().strip()
        # NOTE(review): this is a plain string comparison, so e.g. "0.10" sorts
        # before "0.9" - confirm the version scheme before relying on it.
        if project_version < __version__:
            ok = False
        elif project_version > __version__:
            warnings.warn(
                f"The project expects a higher dlc2action version ({project_version}), please update!"
            )
    if ok:
        return
    # NOTE(review): assumes the project config contains an "annotation" folder - confirm.
    project_annotation_path = os.path.join(self.project_path, "config", "annotation")
    # `__file__` (not `__path__`, which is undefined in a plain module) locates the
    # package config folder, as in the other config helpers of this class.
    annotation_path = os.path.join(
        os.path.dirname(os.path.dirname(__file__)), "config", "annotation"
    )
    episodes = self._episodes()

    project_annotation_configs = os.listdir(project_annotation_path)
    for ann_config in os.listdir(annotation_path):
        source = os.path.join(annotation_path, ann_config)
        target = os.path.join(project_annotation_path, ann_config)
        if ann_config not in project_annotation_configs:
            # The configs are single files; copytree only works on directories.
            if os.path.isdir(source):
                shutil.copytree(source, target, dirs_exist_ok=True)
            else:
                shutil.copy(source, target)
        else:
            project_pars = self._open_yaml(target)
            pars = self._open_yaml(source)
            # NOTE(review): the merged `project_pars` is not written back to the
            # project file here - confirm whether that is intended.
            for key in set(pars.keys()) - set(project_pars.keys()):
                project_pars[key] = pars[key]
                c = self._get_comment(pars.ca.items.get(key))
                project_pars.yaml_add_eol_comment(c, key=key)
                episodes.update(
                    condition=f"general/annotation_type::={ann_config}",
                    update={f"data/{key}": pars[key]},
                )
def _read_types(self) -> Tuple[str, str]:
    """Read the annotation type and the data type from the general config of the project.

    Returns
    -------
    annotation_type : str
        the value stored under "annotation_type" in `config/general.yaml`
    data_type : str
        the value stored under "data_type" in `config/general.yaml`

    """
    general_config = os.path.join(self.project_path, "config", "general.yaml")
    with open(general_config, encoding="utf-8") as stream:
        general = YAML().load(stream)
    data_type, annotation_type = general["data_type"], general["annotation_type"]
    return annotation_type, data_type
def _generate_config(
    self, data_type: str, annotation_type: str, data_path: str, annotation_path: str
) -> None:
    """Initialize the config files.

    The default config files and folders of the package are copied to the project, the default
    data and annotation parameters are merged into `data.yaml` together with the paths, and the
    data and annotation types are recorded in `general.yaml`.

    Parameters
    ----------
    data_type : str
        the data type of the project
    annotation_type : str
        the annotation type of the project (`None` and "none" need no default config)
    data_path : str
        the value to store as "data_path" in `data.yaml`
    annotation_path : str
        the value to store as "annotation_path" in `data.yaml`

    Raises
    ------
    ValueError
        if there is no default config for `annotation_type` and it is not "none"

    """
    default_path = os.path.join(
        os.path.dirname(os.path.dirname(__file__)), "config"
    )
    config_path = os.path.join(self.project_path, "config")
    for name in ["losses", "metrics", "ssl", "training"]:
        shutil.copy(os.path.join(default_path, f"{name}.yaml"), config_path)
    for folder in ["model", "features", "augmentations"]:
        shutil.copytree(
            os.path.join(default_path, folder), os.path.join(config_path, folder)
        )
    yaml = YAML()
    # Stays None when the data type has no default config file; without this
    # initialization the check below would fail on an unbound name.
    data_params = None
    data_param_path = os.path.join(default_path, "data", f"{data_type}.yaml")
    if os.path.exists(data_param_path):
        with open(data_param_path, encoding="utf-8") as f:
            data_params = yaml.load(f)
    if data_params is None:
        data_params = {}
    if annotation_type is None:
        ann_params = {}
    else:
        ann_param_path = os.path.join(
            default_path, "annotation", f"{annotation_type}.yaml"
        )
        if os.path.exists(ann_param_path):
            ann_params = self._open_yaml(ann_param_path)
        elif annotation_type == "none":
            ann_params = {}
        else:
            raise ValueError(
                f"The {annotation_type} annotation type is not available. "
                f"Please choose from {BehaviorDataset.annotation_types()}"
            )
    if ann_params is None:
        ann_params = {}
    data_params = self._update(data_params, ann_params)
    data_params["data_path"] = data_path
    data_params["annotation_path"] = annotation_path
    with open(os.path.join(config_path, "data.yaml"), "w", encoding="utf-8") as f:
        yaml.dump(data_params, f)
    with open(os.path.join(default_path, "general.yaml"), encoding="utf-8") as f:
        general_params = yaml.load(f)
    general_params["data_type"] = data_type
    general_params["annotation_type"] = annotation_type
    with open(
        os.path.join(config_path, "general.yaml"), "w", encoding="utf-8"
    ) as f:
        yaml.dump(general_params, f)
def _open_yaml(self, path: str) -> "Union[CommentedMap, Dict]":
    """Load a parameter dictionary from a .yaml file.

    Parameters
    ----------
    path : str
        the path to the .yaml file

    Returns
    -------
    data : dict
        the loaded parameters (an empty dictionary if the file is empty)

    """
    with open(path, encoding="utf-8") as f:
        data = YAML().load(f)
    # an empty file is loaded as None
    return {} if data is None else data


def _compare(self, d: Dict, u: Dict, allow_diff: float = 1e-7) -> bool:
    """Compare nested dictionaries with 'almost equal' condition.

    Parameters
    ----------
    d : dict
        the first dictionary
    u : dict
        the second dictionary
    allow_diff : float, default 1e-7
        the largest absolute difference at which two numbers are still
        considered equal (applied when at least one of them is a float)

    Returns
    -------
    ok : bool
        `True` if the key sets are equal at every level and all values match

    """
    if u.keys() != d.keys():
        return False
    for k, v in u.items():
        other = d[k]
        if isinstance(v, Mapping):
            # Return on the first mismatch, so that a later matching
            # sub-dictionary cannot hide an earlier difference.
            if not isinstance(other, Mapping):
                return False
            if not self._compare(other, v, allow_diff=allow_diff):
                return False
        elif isinstance(v, float) or isinstance(other, float):
            if not isinstance(v, (float, int)) or not isinstance(other, (float, int)):
                return False
            if abs(v - other) > allow_diff:
                return False
        elif v != other:
            return False
    return True


def _check_comment(self, comment_sequence: List) -> bool:
    """Check if a comment already exists in a ruamel.yaml comment sequence.

    Parameters
    ----------
    comment_sequence : list
        the comment sequence (can be `None`)

    Returns
    -------
    exists : bool
        `True` if the stripped comment string is not empty

    """
    return self._get_comment(comment_sequence) != ""


def _get_comment(self, comment_sequence: List, strip=True) -> str:
    """Get the comment string from a ruamel.yaml comment sequence.

    The value of the first comment token found in the sequence is returned.
    An entry can be `None`, a single token or a collection of tokens.

    Parameters
    ----------
    comment_sequence : list
        the comment sequence (can be `None`)
    strip : bool, default True
        if `True`, leading and trailing whitespace is removed

    Returns
    -------
    comment : str
        the comment string (empty if there is no comment token)

    """
    if comment_sequence is None:
        return ""
    comment = ""
    for entry in comment_sequence:
        if entry is None:
            continue
        tokens = entry if isinstance(entry, Iterable) else [entry]
        token = next((t for t in tokens if t is not None), None)
        # A collection holding no tokens is skipped rather than
        # leaving the result set to None.
        if token is not None:
            comment = token.value
            break
    return comment.strip() if strip else comment
def _update_with_search(
    self,
    d: Dict,
    search_name: str,
    load_parameters: list = None,
    round_to_binary: list = None,
):
    """Update a dictionary with best parameters from a hyperparameter search.

    Parameters
    ----------
    d : dict
        the dictionary to update
    search_name : str
        the name of the hyperparameter search
    load_parameters : list, optional
        passed on to the `get_best_params` method of the searches record
    round_to_binary : list, optional
        passed on to the `get_best_params` method of the searches record

    Returns
    -------
    updated : dict
        the result of `self._update` applied to `d` and the best parameters

    """
    searches = self._searches()
    best_params, _ = searches.get_best_params(
        search_name, load_parameters, round_to_binary
    )
    updated = self._update(d, best_params)
    return updated
def set_main_parameters(self, model_name: str = None, metric_names: List = None):
    """Select the model and the metrics.

    The chosen values are written to the "general" parameter group with
    `self.update_parameters`.

    Parameters
    ----------
    model_name : str, optional
        model name; has to be a key of `options.models`
    metric_names : list, optional
        a list of metric function names; each has to be a key of `options.metrics`

    """
    general = {}
    if model_name is not None:
        assert model_name in options.models
        general["model_name"] = model_name
    if metric_names is not None:
        for metric in metric_names:
            assert metric in options.metrics
        general["metric_functions"] = metric_names
    self.update_parameters({"general": general})
3979 3980 Parameters 3981 ---------- 3982 keyword : str, optional 3983 the keyword for options (run without arguments to see which keywords are available) 3984 3985 """ 3986 if keyword is None: 3987 print("AVAILABLE HELP FUNCTIONS:") 3988 print("- Try running `project.help(keyword)` with the following keywords:") 3989 print(" - model: to get more information on available models,") 3990 print( 3991 " - features: to get more information on available feature extraction modes," 3992 ) 3993 print( 3994 " - partition_method: to get more information on available train/test/val partitioning methods," 3995 ) 3996 print(" - metrics: to see a list of available metric functions.") 3997 print(" - data: to see help for expected data structure") 3998 print( 3999 "- To start working with this project, first run `project.list_blanks()` to check which parameters need to be filled in." 4000 ) 4001 print( 4002 "- After a model and metrics are set, run `project.list_basic_parameters()` to see a list of the most important parameters that you might want to modify" 4003 ) 4004 print( 4005 f"- If you want to dig deeper, get the full dictionary with project._read_parameters() (it is a `ruamel.yaml.comments.CommentedMap` instance)." 
def _process_value(self, value):
    """Prepare a configuration value for printing.

    Parameters
    ----------
    value : any
        the value to process

    Returns
    -------
    processed_value : any
        the string wrapped in double quotes if `value` is a string, a plain
        set if it is a `CommentedSet` and the unchanged value otherwise

    """
    if isinstance(value, str):
        return f'"{value}"'
    if isinstance(value, CommentedSet):
        return set(value)
    return value
4064 4065 Returns 4066 ------- 4067 caught : list 4068 a list of parameter keys that have blank values 4069 4070 """ 4071 caught = [] 4072 parameters = self._read_parameters(catch_blanks=False) 4073 for big_key, big_value in parameters.items(): 4074 for key, value in big_value.items(): 4075 if value == "???": 4076 caught.append( 4077 (big_key, key, self._get_comment(big_value.ca.items.get(key))) 4078 ) 4079 return caught 4080 4081 def list_blanks(self, blanks=None): 4082 """List parameters that need to be filled in. 4083 4084 Parameters 4085 ---------- 4086 blanks : list, optional 4087 a list of the parameters to list, if already known 4088 4089 """ 4090 if blanks is None: 4091 blanks = self._get_blanks() 4092 if len(blanks) > 0: 4093 to_update = defaultdict(lambda: []) 4094 for b, k, c in blanks: 4095 to_update[b].append((k, c)) 4096 print("Before running experiments, please update all the blanks.") 4097 print("To do that, you can run this.") 4098 print("--------------------------------------------------------") 4099 print(f"project.update_parameters(") 4100 print(f" {{") 4101 for big_key, keys in to_update.items(): 4102 print(f' "{big_key}": {{') 4103 for key, comment in keys: 4104 print(f' "{key}": ..., {comment}') 4105 print(f" }}") 4106 print(f" }}") 4107 print(")") 4108 print("--------------------------------------------------------") 4109 print("Replace ... 
with relevant values.") 4110 else: 4111 print("There is no blanks left!") 4112 4113 def list_basic_parameters( 4114 self, 4115 ): 4116 """Get a list of most relevant parameters and code to modify them.""" 4117 parameters = self._read_parameters() 4118 print("BASIC PARAMETERS:") 4119 model_name = parameters["general"]["model_name"] 4120 metric_names = parameters["general"]["metric_functions"] 4121 loss_name = parameters["general"]["loss_function"] 4122 feature_extraction = parameters["general"]["feature_extraction"] 4123 print("Here is a list of current parameters.") 4124 print( 4125 "You can copy this code, change the parameters you want to set and run it to update the project config." 4126 ) 4127 print("--------------------------------------------------------") 4128 print("project.update_parameters(") 4129 print(" {") 4130 for group in ["general", "data", "training"]: 4131 print(f' "{group}": {{') 4132 for key in options.basic_parameters[group]: 4133 if key in parameters[group]: 4134 print( 4135 f' "{key}": {self._process_value(parameters[group][key])}, {self._get_comment(parameters[group].ca.items.get(key))}' 4136 ) 4137 print(" },") 4138 print(' "losses": {') 4139 print(f' "{loss_name}": {{') 4140 for key in options.basic_parameters["losses"][loss_name]: 4141 if key in parameters["losses"][loss_name]: 4142 print( 4143 f' "{key}": {self._process_value(parameters["losses"][loss_name][key])}, {self._get_comment(parameters["losses"][loss_name].ca.items.get(key))}' 4144 ) 4145 print(" },") 4146 print(" },") 4147 print(' "metrics": {') 4148 for metric in metric_names: 4149 print(f' "{metric}": {{') 4150 for key in parameters["metrics"][metric]: 4151 print( 4152 f' "{key}": {self._process_value(parameters["metrics"][metric][key])}, {self._get_comment(parameters["metrics"][metric].ca.items.get(key))}' 4153 ) 4154 print(" },") 4155 print(" },") 4156 print(' "model": {') 4157 for key in options.basic_parameters["model"][model_name]: 4158 if key in parameters["model"]: 
4159 print( 4160 f' "{key}": {self._process_value(parameters["model"][key])}, {self._get_comment(parameters["model"].ca.items.get(key))}' 4161 ) 4162 4163 print(" },") 4164 print(' "features": {') 4165 for key in options.basic_parameters["features"][feature_extraction]: 4166 if key in parameters["features"]: 4167 print( 4168 f' "{key}": {self._process_value(parameters["features"][key])}, {self._get_comment(parameters["features"].ca.items.get(key))}' 4169 ) 4170 4171 print(" },") 4172 print(' "augmentations": {') 4173 for key in options.basic_parameters["augmentations"][feature_extraction]: 4174 if key in parameters["augmentations"]: 4175 print( 4176 f' "{key}": {self._process_value(parameters["augmentations"][key])}, {self._get_comment(parameters["augmentations"].ca.items.get(key))}' 4177 ) 4178 print(" },") 4179 print(" },") 4180 print(")") 4181 print("--------------------------------------------------------") 4182 print("\n") 4183 4184 def _create_record( 4185 self, 4186 episode_name: str, 4187 behaviors_dict: Dict, 4188 load_episode: str = None, 4189 parameters_update: Dict = None, 4190 task: TaskDispatcher = None, 4191 load_epoch: int = None, 4192 load_search: str = None, 4193 load_parameters: list = None, 4194 round_to_binary: list = None, 4195 load_strict: bool = True, 4196 n_seeds: int = 1, 4197 ) -> TaskDispatcher: 4198 """Create a meta data episode record.""" 4199 if episode_name in self._episodes().data.index: 4200 return 4201 if type(n_seeds) is not int or n_seeds < 1: 4202 raise ValueError( 4203 f"The n_seeds parameter has to be an integer larger than 0; got {n_seeds}" 4204 ) 4205 if parameters_update is None: 4206 parameters_update = {} 4207 parameters = self._read_parameters() 4208 parameters = self._update(parameters, parameters_update) 4209 if load_search is not None: 4210 parameters = self._update_with_search( 4211 parameters, load_search, load_parameters, round_to_binary 4212 ) 4213 parameters = self._fill( 4214 parameters, 4215 episode_name, 4216 
def _save_thresholds(
    self,
    episode_names: List,
    metric_name: str,
    parameters: Dict,
    thresholds: List,
    load_epochs: List,
):
    """Save optimal decision thresholds in the meta records.

    Parameters
    ----------
    episode_names : list
        the names of the episodes
    metric_name : str
        the name of the metric (a key of `parameters["metrics"]`)
    parameters : dict
        the parameters dictionary
    thresholds : list
        the thresholds to save
    load_epochs : list
        the epochs that correspond to `episode_names`

    """
    metric_parameters = parameters["metrics"][metric_name]
    self._thresholds().save_thresholds(
        episode_names, load_epochs, metric_name, metric_parameters, thresholds
    )


def _save_episode(
    self,
    episode_name: str,
    parameters: Dict,
    behaviors_dict: Dict,
    suppress_validation: bool = False,
    training_time: str = None,
    norm_stats: Dict = None,
) -> None:
    """Save an episode in the meta files.

    Parameters
    ----------
    episode_name : str
        the name of the episode
    parameters : dict
        the parameters dictionary (its "training" group is modified in place)
    behaviors_dict : dict
        the behaviors dictionary
    suppress_validation : bool, default False
        passed on to the episodes record
    training_time : str, optional
        passed on to the episodes record
    norm_stats : dict, optional
        the normalization statistics; stored under `parameters["training"]["stats"]`

    """
    try:
        split_info = self._split_info_from_filename(
            parameters["training"]["split_path"]
        )
        parameters["training"]["partition_method"] = split_info["partition_method"]
    except Exception:
        # Best effort: the partition method is only filled in when it can be
        # read from the split path. KeyboardInterrupt and SystemExit are not
        # swallowed here.
        pass
    if norm_stats is not None:
        norm_stats = dict(norm_stats)
    parameters["training"]["stats"] = norm_stats
    self._episodes().save_episode(
        episode_name,
        parameters,
        behaviors_dict,
        suppress_validation=suppress_validation,
        training_time=training_time,
    )


def _save_suggestions(
    self, suggestions_name: str, parameters: Dict, meta_parameters: Dict
) -> None:
    """Save a suggestion in the meta files.

    Parameters
    ----------
    suggestions_name : str
        the name of the suggestion
    parameters : dict
        the parameters dictionary
    meta_parameters : dict
        the meta parameters dictionary

    """
    self._suggestions().save_suggestion(
        suggestions_name, parameters, meta_parameters
    )


def _update_episode_results(
    self,
    episode_name: str,
    logs: Tuple,
    training_time: str = None,
) -> None:
    """Save the results of a run in the meta files.

    Parameters
    ----------
    episode_name : str
        the name of the episode
    logs : tuple
        the logs of the run
    training_time : str, optional
        the training time

    """
    self._episodes().update_episode_results(episode_name, logs, training_time)
def _save_search(
    self,
    search_name: str,
    parameters: Dict,
    n_trials: int,
    best_params: Dict,
    best_value: float,
    metric: str,
    search_space: Dict,
) -> None:
    """Save a hyperparameter search in the meta files.

    All arguments are passed on unchanged to the `save_search` method of the
    searches record.

    Parameters
    ----------
    search_name : str
        the name of the search
    parameters : dict
        the parameters dictionary
    n_trials : int
        the number of trials
    best_params : dict
        the best parameters
    best_value : float
        the best value
    metric : str
        the metric name
    search_space : dict
        the search space

    """
    record = self._searches()
    record.save_search(
        search_name,
        parameters,
        n_trials,
        best_params,
        best_value,
        metric,
        search_space,
    )


def _save_stores(self, parameters: Dict) -> None:
    """Save a pickled dataset in the meta files.

    The record is named after the base name of
    `parameters["data"]["feature_save_path"]`; a metadata backup is created
    afterwards.

    Parameters
    ----------
    parameters : dict
        the parameters dictionary

    """
    feature_path = parameters["data"]["feature_save_path"]
    name = os.path.basename(feature_path)
    data_pars = self._get_data_pars(parameters)
    self._saved_datasets().save_store(name, data_pars)
    self.create_metadata_backup()
def _check_episode_validity(
    self, episode_name: str, allow_doublecolon: bool = False, force: bool = False
) -> None:
    """Check whether the episode name is valid.

    Parameters
    ----------
    episode_name : str
        the name to check
    allow_doublecolon : bool, default False
        if `True`, names containing '#' are accepted (names containing '::'
        are rejected regardless of this flag)
    force : bool, default False
        if `True`, an existing episode with this name is removed instead of
        raising an error

    Raises
    ------
    ValueError
        if the name is reserved, contains a forbidden character or is taken

    """
    if episode_name.startswith("_"):
        raise ValueError(
            "Names starting with an underscore are reserved by dlc2action and cannot be used!"
        )
    if "." in episode_name:
        raise ValueError("Names containing '.' cannot be used!")
    if "#" in episode_name and not allow_doublecolon:
        raise ValueError(
            "Names containing '#' are reserved by dlc2action and cannot be used!"
        )
    if "::" in episode_name:
        raise ValueError(
            "Names containing '::' are reserved by dlc2action and cannot be used!"
        )
    if force:
        self.remove_episode(episode_name)
        return
    name_is_free = self._episodes().check_name_validity(episode_name)
    if not name_is_free:
        raise ValueError(
            f"The {episode_name} name is already taken! Set force=True to overwrite."
        )


def _check_search_validity(self, search_name: str, force: bool = False) -> None:
    """Check whether the search name is valid.

    Parameters
    ----------
    search_name : str
        the name to check
    force : bool, default False
        if `True`, an existing search with this name is removed instead of
        raising an error

    Raises
    ------
    ValueError
        if the name starts with an underscore, contains '.' or is taken

    """
    if search_name.startswith("_"):
        raise ValueError(
            "Names starting with an underscore are reserved by dlc2action and cannot be used!"
        )
    if "." in search_name:
        raise ValueError("Names containing '.' cannot be used!")
    if force:
        self.remove_search(search_name)
        return
    name_is_free = self._searches().check_name_validity(search_name)
    if not name_is_free:
        raise ValueError(f"The {search_name} name is already taken!")
def _check_suggestions_validity(
    self, suggestions_name: str, force: bool = False
) -> None:
    """Check whether the suggestions name is valid.

    Parameters
    ----------
    suggestions_name : str
        the name to check
    force : bool, default False
        if `True`, an existing suggestion with this name is removed instead
        of raising an error

    Raises
    ------
    ValueError
        if the name starts with an underscore, contains '.' or is taken

    """
    if suggestions_name.startswith("_"):
        raise ValueError(
            "Names starting with an underscore are reserved by dlc2action and cannot be used!"
        )
    if "." in suggestions_name:
        raise ValueError("Names containing '.' cannot be used!")
    if force:
        self.remove_suggestion(suggestions_name)
        return
    name_is_free = self._suggestions().check_name_validity(suggestions_name)
    if not name_is_free:
        raise ValueError(f"The {suggestions_name} name is already taken!")


def _training_time(self, episode_name: str) -> int:
    """Get the training time of an episode in seconds.

    Parameters
    ----------
    episode_name : str
        the name of the episode

    Returns
    -------
    training_time : int
        the value returned by the `training_time` method of the episode record

    """
    episode = self._episode(episode_name)
    return episode.training_time()


def _mask_path(self) -> str:
    """Get the path to the masks folder.

    Returns
    -------
    path : str
        the "results/masks" folder inside the project folder

    """
    results_dir = os.path.join(self.project_path, "results")
    return os.path.join(results_dir, "masks")


def _thresholds_path(self) -> str:
    """Get the path to the thresholds meta file.

    Returns
    -------
    path : str
        the "meta/thresholds.pickle" file inside the project folder

    """
    meta_dir = os.path.join(self.project_path, "meta")
    return os.path.join(meta_dir, "thresholds.pickle")


def _episodes_path(self) -> str:
    """Get the path to the episodes meta file.

    Returns
    -------
    path : str
        the "meta/episodes.pickle" file inside the project folder

    """
    meta_dir = os.path.join(self.project_path, "meta")
    return os.path.join(meta_dir, "episodes.pickle")
def _saved_datasets_path(self) -> str:
    """Get the path to the datasets meta file.

    Returns
    -------
    path : str
        the "meta/saved_datasets.pickle" file inside the project folder

    """
    meta_dir = os.path.join(self.project_path, "meta")
    return os.path.join(meta_dir, "saved_datasets.pickle")


def _predictions_path(self) -> str:
    """Get the path to the predictions meta file.

    Returns
    -------
    path : str
        the "meta/predictions.pickle" file inside the project folder

    """
    meta_dir = os.path.join(self.project_path, "meta")
    return os.path.join(meta_dir, "predictions.pickle")


def _dataset_store_path(self, name: str) -> str:
    """Get the path to a specific pickled dataset.

    Parameters
    ----------
    name : str
        the name of the dataset

    Returns
    -------
    path : str
        the "saved_datasets/<name>.pickle" file inside the project folder

    """
    file_name = f"{name}.pickle"
    return os.path.join(self.project_path, "saved_datasets", file_name)


def _al_points_path(self, suggestions_name: str) -> str:
    """Get the path to an active learning intervals file.

    Parameters
    ----------
    suggestions_name : str
        the name of the suggestions

    Returns
    -------
    path : str
        the "<suggestions_name>_al_points.pickle" file inside the
        "results/suggestions/<suggestions_name>" folder of the project

    """
    folder = os.path.join(
        self.project_path, "results", "suggestions", suggestions_name
    )
    return os.path.join(folder, f"{suggestions_name}_al_points.pickle")
def _searches_path(self) -> str:
    """Get the path to the hyperparameter search meta file.

    Returns
    -------
    path : str
        the "meta/searches.pickle" file inside the project folder

    """
    return os.path.join(self.project_path, "meta", "searches.pickle")


def _search_path(self, name: str) -> str:
    """Get the default path to the graph folder for a specific hyperparameter search.

    Parameters
    ----------
    name : str
        the name of the search

    Returns
    -------
    path : str
        the "results/searches/<name>" folder inside the project folder

    """
    return os.path.join(self.project_path, "results", "searches", name)


def _version_path(self) -> str:
    """Get the path to the version file.

    Returns
    -------
    path : str
        the "meta/version.txt" file inside the project folder

    """
    return os.path.join(self.project_path, "meta", "version.txt")


def _default_split_file(self, split_info: Dict) -> Optional[str]:
    """Generate a path to a split file from split parameters.

    Parameters
    ----------
    split_info : dict
        the split parameters dictionary with "partition_method", "val_frac",
        "test_frac", "len_segment", "overlap" and "only_load_annotated" keys

    Returns
    -------
    split_file_path : str or None
        the path to the split file in the "results/splits" folder, or `None`
        if the partition method starts with "time"

    """
    if split_info["partition_method"].startswith("time"):
        return None
    val_frac = split_info["val_frac"]
    test_frac = split_info["test_frac"]
    split_name = f'{split_info["partition_method"]}_val{val_frac * 100}%_test{test_frac * 100}%_len{split_info["len_segment"]}_overlap{split_info["overlap"]}'
    if not split_info["only_load_annotated"]:
        split_name += "_all"
    split_name += ".txt"
    return os.path.join(self.project_path, "results", "splits", split_name)


def _split_info_from_filename(self, split_name: str) -> Dict:
    """Get split parameters from default path to a split file.

    The file name is expected to follow the pattern generated by
    `_default_split_file`; any other name is treated as a user-defined file.

    Parameters
    ----------
    split_name : str
        the name/path of the split file

    Returns
    -------
    split_info : dict
        the split parameters dictionary; empty if `split_name` is `None` and
        ``{"partition_method": "file"}`` if the name cannot be parsed

    """
    if split_name is None:
        return {}
    try:
        # drop the four-character ".txt" extension
        name = os.path.basename(split_name)[:-4]
        split = name.split("_")
        # a sixth "_all" field is added when unannotated data is loaded too
        only_load_annotated = len(split) != 6
        len_segment = int(split[3][3:])
        overlap = float(split[4][7:])
        if overlap > 1:
            overlap = int(overlap)
        method, val, test = split[:3]
        val = float(val[3:-1]) / 100
        test = float(test[4:-1]) / 100
    except Exception:
        # Unparsable names fall back to "file"; KeyboardInterrupt and
        # SystemExit are deliberately not caught.
        return {"partition_method": "file"}
    return {
        "partition_method": method,
        "val_frac": val,
        "test_frac": test,
        "only_load_annotated": only_load_annotated,
        "len_segment": len_segment,
        "overlap": overlap,
    }
continuing: bool = False, 4667 enforce_split_parameters: bool = False, 4668 ) -> Dict: 4669 """Update the parameters from the config files with project specific information. 4670 4671 Fill in the constant file path parameters and generate a unique log file and a model folder. 4672 Fill in the split file if the same split has been run before in the project and change partition method to 4673 from_file. 4674 Fill in saved data path if a dataset with the same data parameters already exists in the project. 4675 If load_experiment is not None, fill in the checkpoint path as well. 4676 The only_load_model training parameter is defined by the corresponding argument. 4677 If continuing is True, new files are not created and all information is loaded from load_experiment. 4678 If prediction is True, log and model files are not created. 4679 The enforce_split_parameters parameter is used to resolve conflicts 4680 between split file path and split parameters when they arise. 4681 4682 Parameters 4683 ---------- 4684 parameters : dict 4685 the parameters dictionary to update 4686 episode_name : str 4687 the name of the episode 4688 load_experiment : str, optional 4689 the name of the experiment to load from 4690 load_epoch : int, optional 4691 the epoch to load (by default the last one) 4692 load_strict : bool, default True 4693 if `True`, strict loading is enforced 4694 only_load_model : bool, default False 4695 if `True`, only the model is loaded 4696 continuing : bool, default False 4697 if `True`, continues from existing files 4698 enforce_split_parameters : bool, default False 4699 if `True`, split parameters are enforced 4700 4701 Returns 4702 ------- 4703 parameters : dict 4704 the updated parameters dictionary 4705 4706 """ 4707 pars = deepcopy(parameters) 4708 if episode_name == "_": 4709 self.remove_episode("_") 4710 log = os.path.join(self.project_path, "results", "logs", f"{episode_name}.txt") 4711 model_save_path = os.path.join( 4712 self.project_path, "results", 
"model", episode_name 4713 ) 4714 if not continuing and (os.path.exists(log) or os.path.exists(model_save_path)): 4715 raise ValueError( 4716 f"The {episode_name} episode name is already in use! Set force=True to overwrite." 4717 ) 4718 keys = ["val_frac", "test_frac", "partition_method"] 4719 if "len_segment" not in pars["general"] and "len_segment" in pars["data"]: 4720 pars["general"]["len_segment"] = pars["data"]["len_segment"] 4721 if "overlap" not in pars["general"] and "overlap" in pars["data"]: 4722 pars["general"]["overlap"] = pars["data"]["overlap"] 4723 if "len_segment" in pars["data"]: 4724 pars["data"].pop("len_segment") 4725 if "overlap" in pars["data"]: 4726 pars["data"].pop("overlap") 4727 split_info = {k: pars["training"][k] for k in keys} 4728 split_info["only_load_annotated"] = pars["general"]["only_load_annotated"] 4729 split_info["len_segment"] = pars["general"]["len_segment"] 4730 split_info["overlap"] = pars["general"]["overlap"] 4731 pars["training"]["log_file"] = log 4732 if not os.path.exists(model_save_path): 4733 os.mkdir(model_save_path) 4734 pars["training"]["model_save_path"] = model_save_path 4735 if load_experiment is not None: 4736 if load_experiment not in self._episodes().data.index: 4737 raise ValueError(f"The {load_experiment} episode does not exist!") 4738 old_episode = self._episode(load_experiment) 4739 old_file = old_episode.split_file() 4740 old_info = self._split_info_from_filename(old_file) 4741 if len(old_info) == 0: 4742 old_info = old_episode.split_info() 4743 if enforce_split_parameters: 4744 if split_info["partition_method"] != "file": 4745 pars["training"]["split_path"] = self._default_split_file( 4746 split_info 4747 ) 4748 else: 4749 equal = True 4750 if old_info["partition_method"] != split_info["partition_method"]: 4751 equal = False 4752 if old_info["partition_method"] != "file": 4753 if ( 4754 old_info["val_frac"] != split_info["val_frac"] 4755 or old_info["test_frac"] != split_info["test_frac"] 4756 ): 4757 
equal = False 4758 if not continuing and not equal: 4759 warnings.warn( 4760 f"The partitioning parameters in the loaded experiment ({old_info}) " 4761 f"are not equal to the current partitioning parameters ({split_info}). " 4762 f"The current parameters are replaced." 4763 ) 4764 pars["training"]["split_path"] = old_file 4765 for k, v in old_info.items(): 4766 pars["training"][k] = v 4767 pars["training"]["checkpoint_path"] = old_episode.model_file(load_epoch) 4768 pars["training"]["load_strict"] = load_strict 4769 else: 4770 pars["training"]["checkpoint_path"] = None 4771 if pars["training"]["partition_method"] == "file": 4772 if ( 4773 "split_path" not in pars["training"] 4774 or pars["training"]["split_path"] is None 4775 ): 4776 raise ValueError( 4777 "The partition_method parameter is set to file but the " 4778 "split_path parameter is not set!" 4779 ) 4780 elif not os.path.exists(pars["training"]["split_path"]): 4781 raise ValueError( 4782 f'The {pars["training"]["split_path"]} split file does not exist' 4783 ) 4784 else: 4785 pars["training"]["split_path"] = self._default_split_file(split_info) 4786 pars["training"]["only_load_model"] = only_load_model 4787 pars["data"]["saved_data_path"] = None 4788 pars["data"]["feature_save_path"] = None 4789 pars_data_copy = self._get_data_pars(pars) 4790 saved_data_name = self._saved_datasets().find_name(pars_data_copy) 4791 if saved_data_name is not None: 4792 pars["data"]["saved_data_path"] = self._dataset_store_path(saved_data_name) 4793 pars["data"]["feature_save_path"] = self._dataset_store_path( 4794 saved_data_name 4795 ).split(".")[0] 4796 else: 4797 dataset_path = self._dataset_store_path(episode_name) 4798 if os.path.exists(dataset_path): 4799 name, ext = dataset_path.split(".") 4800 i = 0 4801 while os.path.exists(f"{name}_{i}.{ext}"): 4802 i += 1 4803 dataset_path = f"{name}_{i}.{ext}" 4804 pars["data"]["saved_data_path"] = dataset_path 4805 pars["data"]["feature_save_path"] = dataset_path.split(".")[0] 
4806 split_split = pars["training"]["partition_method"].split(":") 4807 random = True 4808 for partition_method in options.partition_methods["fixed"]: 4809 method_split = partition_method.split(":") 4810 if len(split_split) != len(method_split): 4811 continue 4812 equal = True 4813 for x, y in zip(split_split, method_split): 4814 if y.startswith("{"): 4815 continue 4816 if x != y: 4817 equal = False 4818 break 4819 if equal: 4820 random = False 4821 break 4822 if random and os.path.exists(pars["training"]["split_path"]): 4823 pars["training"]["partition_method"] = "file" 4824 pars["general"]["save_dataset"] = True 4825 # Check len_segment for c2f models 4826 if pars["general"]["model_name"].startswith("c2f"): 4827 if int(pars["general"]["len_segment"]) < 512: 4828 raise ValueError( 4829 "The segment length should be higher than 512 when using one of the C2F models" 4830 ) 4831 return pars 4832 4833 def _get_data_pars(self, pars: Dict) -> Dict: 4834 """Get a complete description of the data from a general parameters dictionary. 
4835 4836 Parameters 4837 ---------- 4838 pars : dict 4839 the general parameters dictionary 4840 4841 Returns 4842 ------- 4843 pars_data : dict 4844 the complete data parameters dictionary 4845 4846 """ 4847 pars_data_copy = deepcopy(pars["data"]) 4848 for par in [ 4849 "only_load_annotated", 4850 "exclusive", 4851 "feature_extraction", 4852 "ignored_clips", 4853 "len_segment", 4854 "overlap", 4855 ]: 4856 pars_data_copy[par] = pars["general"].get(par, None) 4857 pars_data_copy.update(pars["features"]) 4858 return pars_data_copy 4859 4860 def _make_al_points_from_suggestions( 4861 self, 4862 suggestions_name: str, 4863 task: TaskDispatcher, 4864 predicted_classes: Dict, 4865 background_threshold: Optional[float], 4866 visibility_min_score: float, 4867 visibility_min_frac: float, 4868 num_behaviors: int, 4869 ): 4870 valleys = [] 4871 if background_threshold is not None: 4872 for i in range(num_behaviors): 4873 print(f"generating background for behavior {i}...") 4874 valleys.append( 4875 task.dataset("train").find_valleys( 4876 predicted_classes, 4877 threshold=background_threshold, 4878 visibility_min_score=visibility_min_score, 4879 visibility_min_frac=visibility_min_frac, 4880 main_class=i, 4881 low=True, 4882 cut_annotated=True, 4883 min_frames=1, 4884 ) 4885 ) 4886 valleys = task.dataset("train").valleys_intersection(valleys) 4887 folder = os.path.join( 4888 self.project_path, "results", "suggestions", suggestions_name 4889 ) 4890 os.makedirs(os.path.dirname(folder), exist_ok=True) 4891 res = {} 4892 for file in os.listdir(folder): 4893 video_id = file.split("_suggestion.p")[0] 4894 res[video_id] = [] 4895 with open(os.path.join(folder, file), "rb") as f: 4896 data = pickle.load(f) 4897 for clip_id, ind_list in zip(data[2], data[3]): 4898 max_len = max( 4899 [ 4900 max([x[1] for x in cat_list]) if len(cat_list) > 0 else 0 4901 for cat_list in ind_list 4902 ] 4903 ) 4904 if max_len == 0: 4905 continue 4906 arr = torch.zeros(max_len) 4907 for cat_list in 
ind_list: 4908 for start, end, amb in cat_list: 4909 arr[start:end] = 1 4910 if video_id in valleys: 4911 for start, end, clip in valleys[video_id]: 4912 if clip == clip_id: 4913 arr[start:end] = 1 4914 output, indices, counts = torch.unique_consecutive( 4915 arr > 0, return_inverse=True, return_counts=True 4916 ) 4917 long_indices = torch.where(output)[0] 4918 res[video_id] += [ 4919 ( 4920 (indices == i).nonzero(as_tuple=True)[0][0].item(), 4921 (indices == i).nonzero(as_tuple=True)[0][-1].item(), 4922 clip_id, 4923 ) 4924 for i in long_indices 4925 ] 4926 return res 4927 4928 def _make_al_points( 4929 self, 4930 task: TaskDispatcher, 4931 predicted_error: torch.Tensor, 4932 predicted_classes: torch.Tensor, 4933 exclude_classes: List, 4934 exclude_threshold: List, 4935 exclude_threshold_diff: List, 4936 exclude_hysteresis: List, 4937 include_classes: List, 4938 include_threshold: List, 4939 include_threshold_diff: List, 4940 include_hysteresis: List, 4941 error_episode: str = None, 4942 error_class: str = None, 4943 suggestion_episodes: List = None, 4944 error_threshold: float = 0.5, 4945 error_threshold_diff: float = 0.1, 4946 error_hysteresis: bool = False, 4947 min_frames_al: int = 30, 4948 visibility_min_score: float = 5, 4949 visibility_min_frac: float = 0.7, 4950 ) -> Dict: 4951 """Generate an active learning file.""" 4952 if len(exclude_classes) > 0 or len(include_classes) > 0: 4953 valleys = [] 4954 included = None 4955 excluded = None 4956 for class_name, thr, thr_diff, hysteresis in zip( 4957 exclude_classes, 4958 exclude_threshold, 4959 exclude_threshold_diff, 4960 exclude_hysteresis, 4961 ): 4962 episode = self._episodes().get_runs(suggestion_episodes[0])[0] 4963 class_index = self._episode(episode).get_class_ind(class_name) 4964 valleys.append( 4965 task.dataset("train").find_valleys( 4966 predicted_classes, 4967 predicted_error=predicted_error, 4968 min_frames=min_frames_al, 4969 threshold=thr, 4970 visibility_min_score=visibility_min_score, 4971 
visibility_min_frac=visibility_min_frac, 4972 error_threshold=error_threshold, 4973 main_class=class_index, 4974 low=True, 4975 threshold_diff=thr_diff, 4976 min_frames_error=min_frames_al, 4977 hysteresis=hysteresis, 4978 ) 4979 ) 4980 if len(valleys) > 0: 4981 included = task.dataset("train").valleys_union(valleys) 4982 valleys = [] 4983 for class_name, thr, thr_diff, hysteresis in zip( 4984 include_classes, 4985 include_threshold, 4986 include_threshold_diff, 4987 include_hysteresis, 4988 ): 4989 episode = self._episodes().get_runs(suggestion_episodes[0])[0] 4990 class_index = self._episode(episode).get_class_ind(class_name) 4991 valleys.append( 4992 task.dataset("train").find_valleys( 4993 predicted_classes, 4994 predicted_error=predicted_error, 4995 min_frames=min_frames_al, 4996 threshold=thr, 4997 visibility_min_score=visibility_min_score, 4998 visibility_min_frac=visibility_min_frac, 4999 error_threshold=error_threshold, 5000 main_class=class_index, 5001 low=False, 5002 threshold_diff=thr_diff, 5003 min_frames_error=min_frames_al, 5004 hysteresis=hysteresis, 5005 ) 5006 ) 5007 if len(valleys) > 0: 5008 excluded = task.dataset("train").valleys_union(valleys) 5009 al_points = task.dataset("train").valleys_intersection([included, excluded]) 5010 else: 5011 class_index = self._episode(error_episode).get_class_ind(error_class) 5012 print("generating active learning intervals...") 5013 al_points = task.dataset("train").find_valleys( 5014 predicted_error, 5015 min_frames=min_frames_al, 5016 threshold=error_threshold, 5017 visibility_min_score=visibility_min_score, 5018 visibility_min_frac=visibility_min_frac, 5019 main_class=class_index, 5020 low=True, 5021 threshold_diff=error_threshold_diff, 5022 min_frames_error=min_frames_al, 5023 hysteresis=error_hysteresis, 5024 ) 5025 for v_id in al_points: 5026 clip_dict = defaultdict(lambda: []) 5027 res = [] 5028 for x in al_points[v_id]: 5029 clip_dict[x[-1]].append(x) 5030 for clip_id in clip_dict: 5031 
clip_dict[clip_id] = sorted(clip_dict[clip_id]) 5032 i = 0 5033 j = 1 5034 while j < len(clip_dict[clip_id]): 5035 end = clip_dict[clip_id][i][1] 5036 start = clip_dict[clip_id][j][0] 5037 if start - end < 30: 5038 clip_dict[clip_id][i][1] = clip_dict[clip_id][j][1] 5039 else: 5040 res.append(clip_dict[clip_id][i]) 5041 i = j 5042 j += 1 5043 res.append(clip_dict[clip_id][i]) 5044 al_points[v_id] = sorted(res) 5045 return al_points 5046 5047 def _make_suggestions( 5048 self, 5049 task: TaskDispatcher, 5050 predicted_error: torch.Tensor, 5051 predicted_classes: torch.Tensor, 5052 suggestion_threshold: List, 5053 suggestion_threshold_diff: List, 5054 suggestion_hysteresis: List, 5055 suggestion_episodes: List = None, 5056 suggestion_classes: List = None, 5057 error_threshold: float = 0.5, 5058 min_frames_suggestion: int = 3, 5059 min_frames_al: int = 30, 5060 visibility_min_score: float = 0, 5061 visibility_min_frac: float = 0.7, 5062 cut_annotated: bool = False, 5063 ) -> Dict: 5064 """Make a suggestions dictionary.""" 5065 suggestions = defaultdict(lambda: {}) 5066 for class_name, thr, thr_diff, hysteresis in zip( 5067 suggestion_classes, 5068 suggestion_threshold, 5069 suggestion_threshold_diff, 5070 suggestion_hysteresis, 5071 ): 5072 episode = self._episodes().get_runs(suggestion_episodes[0])[0] 5073 class_index = self._episode(episode).get_class_ind(class_name) 5074 print(f"generating suggestions for {class_name}...") 5075 found = task.dataset("train").find_valleys( 5076 predicted_classes, 5077 smooth_interval=2, 5078 predicted_error=predicted_error, 5079 min_frames=min_frames_suggestion, 5080 threshold=thr, 5081 visibility_min_score=visibility_min_score, 5082 visibility_min_frac=visibility_min_frac, 5083 error_threshold=error_threshold, 5084 main_class=class_index, 5085 low=False, 5086 threshold_diff=thr_diff, 5087 min_frames_error=min_frames_al, 5088 hysteresis=hysteresis, 5089 cut_annotated=cut_annotated, 5090 ) 5091 for v_id in found: 5092 
suggestions[v_id][class_name] = found[v_id] 5093 suggestions = dict(suggestions) 5094 return suggestions 5095 5096 def count_classes( 5097 self, 5098 load_episode: str = None, 5099 parameters_update: Dict = None, 5100 remove_saved_features: bool = False, 5101 bouts: bool = True, 5102 ) -> Dict: 5103 """Get a dictionary of class counts in different modes. 5104 5105 Parameters 5106 ---------- 5107 load_episode : str, optional 5108 the episode settings to load 5109 parameters_update : dict, optional 5110 a dictionary of parameter updates (only for "data" and "general" categories) 5111 remove_saved_features : bool, default False 5112 if `True`, the dataset that is used for computation is then deleted 5113 bouts : bool, default False 5114 if `True`, instead of frame counts segment counts are returned 5115 5116 Returns 5117 ------- 5118 class_counts : dict 5119 a dictionary where first-level keys are "train", "val" and "test", second-level keys are 5120 class names and values are class counts (in frames) 5121 5122 """ 5123 if load_episode is None: 5124 task, parameters = self._make_task_training( 5125 episode_name="_", parameters_update=parameters_update, throwaway=True 5126 ) 5127 else: 5128 task, parameters, _ = self._make_task_prediction( 5129 "_", 5130 load_episode=load_episode, 5131 parameters_update=parameters_update, 5132 ) 5133 class_counts = task.count_classes(bouts=bouts) 5134 behaviors = task.behaviors_dict() 5135 class_counts = { 5136 kk: {behaviors.get(k, "unknown"): v for k, v in vv.items()} 5137 for kk, vv in class_counts.items() 5138 } 5139 if remove_saved_features: 5140 self._remove_stores(parameters) 5141 return class_counts 5142 5143 def plot_class_distribution( 5144 self, 5145 parameters_update: Dict = None, 5146 frame_cutoff: int = 1, 5147 bout_cutoff: int = 1, 5148 print_full: bool = False, 5149 remove_saved_features: bool = False, 5150 save: str = None, 5151 ) -> None: 5152 """Make a class distribution plot. 
5153 5154 You can either specify the parameters, choose an existing dataset or do neither (in that case a dataset 5155 is created or loaded for the computation with the default parameters). 5156 5157 Parameters 5158 ---------- 5159 parameters_update : dict, optional 5160 a dictionary of parameter updates (only for "data" and "general" categories) 5161 frame_cutoff : int, default 1 5162 the minimum number of frames for a segment to be considered 5163 bout_cutoff : int, default 1 5164 the minimum number of bouts for a class to be considered 5165 print_full : bool, default False 5166 if `True`, the full class distribution is printed 5167 remove_saved_features : bool, default False 5168 if `True`, the dataset that is used for computation is then deleted 5169 5170 """ 5171 task, parameters = self._make_task_training( 5172 episode_name="_", parameters_update=parameters_update, throwaway=True 5173 ) 5174 cutoff = {True: bout_cutoff, False: frame_cutoff} 5175 for bouts in [True, False]: 5176 class_counts = task.count_classes(bouts=bouts) 5177 if print_full: 5178 print("Bouts:" if bouts else "Frames:") 5179 for k, v in class_counts.items(): 5180 if sum(v.values()) != 0: 5181 print(f" {k}:") 5182 values, keys = zip( 5183 *[ 5184 x 5185 for x in sorted(zip(v.values(), v.keys()), reverse=True) 5186 if x[-1] != -100 5187 ] 5188 ) 5189 for kk, vv in zip(keys, values): 5190 print(f" {task.behaviors_dict()[kk]}: {vv}") 5191 class_counts = { 5192 kk: {k: v for k, v in vv.items() if v >= cutoff[bouts]} 5193 for kk, vv in class_counts.items() 5194 } 5195 for key, d in class_counts.items(): 5196 if sum(d.values()) != 0: 5197 values, keys = zip( 5198 *[x for x in sorted(zip(d.values(), d.keys())) if x[-1] != -100] 5199 ) 5200 keys = [task.behaviors_dict()[x] for x in keys] 5201 plt.bar(keys, values) 5202 plt.title(key) 5203 plt.xticks(rotation=45, ha="right") 5204 if bouts: 5205 plt.ylabel("bouts") 5206 else: 5207 plt.ylabel("frames") 5208 plt.tight_layout() 5209 5210 if save is None: 
5211 plt.savefig(save) 5212 plt.close() 5213 else: 5214 plt.show() 5215 if remove_saved_features: 5216 self._remove_stores(parameters) 5217 5218 def _generate_mask( 5219 self, 5220 mask_name: str, 5221 perc_annotated: float = 0.1, 5222 parameters_update: Dict = None, 5223 remove_saved_features: bool = False, 5224 ) -> None: 5225 """Generate a real_lens for active learning simulation. 5226 5227 Parameters 5228 ---------- 5229 mask_name : str 5230 the name of the real_lens 5231 perc_annotated : float, default 0.1 5232 a 5233 5234 """ 5235 print(f"GENERATING {mask_name}") 5236 task, parameters = self._make_task_training( 5237 f"_{mask_name}", parameters_update=parameters_update, throwaway=True 5238 ) 5239 val_intervals, val_ids = task.dataset("val").get_intervals() # 1 5240 unannotated_intervals = task.dataset("train").get_unannotated_intervals() # 2 5241 unannotated_intervals = task.dataset("val").get_unannotated_intervals( 5242 first_intervals=unannotated_intervals 5243 ) 5244 ids = task.dataset("train").get_ids() 5245 mask = {video_id: {} for video_id in ids} 5246 total_all = 0 5247 total_masked = 0 5248 for video_id, clip_ids in ids.items(): 5249 for clip_id in clip_ids: 5250 frames = np.ones(task.dataset("train").get_len(video_id, clip_id)) 5251 if clip_id in val_intervals[video_id]: 5252 for start, end in val_intervals[video_id][clip_id]: 5253 frames[start:end] = 0 5254 if clip_id in unannotated_intervals[video_id]: 5255 for start, end in unannotated_intervals[video_id][clip_id]: 5256 frames[start:end] = 0 5257 annotated = np.where(frames)[0] 5258 total_all += len(annotated) 5259 masked = annotated[-int(len(annotated) * (1 - perc_annotated)) :] 5260 total_masked += len(masked) 5261 mask[video_id][clip_id] = self._get_intervals(masked) 5262 file = { 5263 "masked": mask, 5264 "val_intervals": val_intervals, 5265 "val_ids": val_ids, 5266 "unannotated": unannotated_intervals, 5267 } 5268 self._save_mask(file, mask_name) 5269 if remove_saved_features: 5270 
self._remove_stores(parameters) 5271 print("\n") 5272 # print(f'Unmasked: {sum([(vv == 0).sum() for v in real_lens.values() for vv in v.values()])} frames') 5273 5274 def _get_intervals(self, frame_indices: np.ndarray): 5275 """Get a list of intervals from a list of frame indices. 5276 5277 Example: `[0, 1, 2, 5, 6, 8] -> [[0, 3], [5, 7], [8, 9]]`. 5278 5279 Parameters 5280 ---------- 5281 frame_indices : np.ndarray 5282 a list of frame indices 5283 5284 Returns 5285 ------- 5286 intervals : list 5287 a list of interval boundaries 5288 5289 """ 5290 masked_intervals = [] 5291 if len(frame_indices) > 0: 5292 breaks = np.where(np.diff(frame_indices) != 1)[0] 5293 start = frame_indices[0] 5294 for k in breaks: 5295 masked_intervals.append([start, frame_indices[k] + 1]) 5296 start = frame_indices[k + 1] 5297 masked_intervals.append([start, frame_indices[-1] + 1]) 5298 return masked_intervals 5299 5300 def _update_mask_with_uncertainty( 5301 self, 5302 mask_name: str, 5303 episode_name: Union[str, None], 5304 classes: List, 5305 load_epoch: int = None, 5306 n_frames: int = 10000, 5307 method: str = "least_confidence", 5308 min_length: int = 30, 5309 augment_n: int = 0, 5310 parameters_update: Dict = None, 5311 ): 5312 """Update real_lens with frame-wise uncertainty scores for active learning. 
5313 5314 Parameters 5315 ---------- 5316 mask_name : str 5317 the name of the real_lens 5318 episode_name : str 5319 the name of the episode to load 5320 classes : list 5321 a list of class names or indices; their uncertainty scores will be computed separately and stacked 5322 load_epoch : int, optional 5323 the epoch to load (by default last; if this epoch is not saved the closest checkpoint is chosen) 5324 n_frames : int, default 10000 5325 the number of frames to "annotate" 5326 method : {"least_confidence", "entropy"} 5327 the method used to calculate the scores from the probability predictions (`"least_confidence"`: `1 - p_i` if 5328 `p_i > 0.5` or `p_i` if `p_i < 0.5`; `"entropy"`: `- p_i * log(p_i) - (1 - p_i) * log(1 - p_i)`) 5329 min_length : int 5330 the minimum length (in frames) of the annotated intervals 5331 augment_n : int, default 0 5332 the number of augmentations to average over 5333 parameters_update : dict, optional 5334 the dictionary used to update the parameters from the config 5335 5336 Returns 5337 ------- 5338 score_dicts : dict 5339 a nested dictionary where first level keys are video ids, second level keys are clip ids and values 5340 are score tensors 5341 5342 """ 5343 print(f"UPDATING {mask_name}") 5344 task, parameters, _ = self._make_task_prediction( 5345 prediction_name=mask_name, 5346 load_episode=episode_name, 5347 parameters_update=parameters_update, 5348 load_epoch=load_epoch, 5349 mode="train", 5350 ) 5351 score_tensors = task.generate_uncertainty_score(classes, augment_n, method) 5352 self._update_mask(task, mask_name, score_tensors, n_frames, min_length) 5353 print("\n") 5354 5355 def _update_mask_with_BALD( 5356 self, 5357 mask_name: str, 5358 episode_name: str, 5359 classes: List, 5360 load_epoch: int = None, 5361 augment_n: int = 0, 5362 n_frames: int = 10000, 5363 num_models: int = 10, 5364 kernel_size: int = 11, 5365 min_length: int = 30, 5366 parameters_update: Dict = None, 5367 ): 5368 """Update real_lens with 
frame-wise Bayesian Active Learning by Disagreement scores for active learning. 5369 5370 Parameters 5371 ---------- 5372 mask_name : str 5373 the name of the real_lens 5374 episode_name : str 5375 the name of the episode to load 5376 classes : list 5377 a list of class names or indices; their uncertainty scores will be computed separately and stacked 5378 load_epoch : int, optional 5379 the epoch to load (by default last) 5380 augment_n : int, default 0 5381 the number of augmentations to average over 5382 n_frames : int, default 10000 5383 the number of frames to "annotate" 5384 num_models : int, default 10 5385 the number of dropout masks to apply 5386 kernel_size : int, default 11 5387 the size of the smoothing gaussian kernel 5388 min_length : int 5389 the minimum length (in frames) of the annotated intervals 5390 parameters_update : dict, optional 5391 the dictionary used to update the parameters from the config 5392 5393 Returns 5394 ------- 5395 score_dicts : dict 5396 a nested dictionary where first level keys are video ids, second level keys are clip ids and values 5397 are score tensors 5398 5399 """ 5400 print(f"UPDATING {mask_name}") 5401 task, parameters, mode = self._make_task_prediction( 5402 mask_name, 5403 load_episode=episode_name, 5404 parameters_update=parameters_update, 5405 load_epoch=load_epoch, 5406 ) 5407 score_tensors = task.generate_bald_score( 5408 classes, augment_n, num_models, kernel_size 5409 ) 5410 self._update_mask(task, mask_name, score_tensors, n_frames, min_length) 5411 print("\n") 5412 5413 def _suggest_intervals( 5414 self, 5415 dataset: BehaviorDataset, 5416 score_tensors: Dict, 5417 n_frames: int, 5418 min_length: int, 5419 ) -> Dict: 5420 """Suggest intervals with highest score of total length `n_frames`. 
5421 5422 Parameters 5423 ---------- 5424 dataset : BehaviorDataset 5425 the dataset 5426 score_tensors : dict 5427 a dictionary where keys are clip ids and values are framewise score tensors 5428 n_frames : int 5429 the number of frames to "annotate" 5430 min_length : int 5431 minimum length of suggested intervals 5432 5433 Returns 5434 ------- 5435 active_learning_intervals : Dict 5436 active learning dictionary with suggested intervals 5437 5438 """ 5439 video_intervals, _ = dataset.get_intervals() 5440 taken = { 5441 video_id: defaultdict(lambda: {}) for video_id in video_intervals.keys() 5442 } 5443 annotated = dataset.get_annotated_intervals() 5444 for video_id in video_intervals: 5445 for clip_id in video_intervals[video_id]: 5446 taken[video_id][clip_id] = torch.zeros( 5447 dataset.get_len(video_id, clip_id) 5448 ) 5449 if video_id in annotated and clip_id in annotated[video_id]: 5450 for start, end in annotated[video_id][clip_id]: 5451 score_tensors[video_id][clip_id][:, start:end] = -10 5452 taken[video_id][clip_id][int(start) : int(end)] = 1 5453 n_frames = ( 5454 sum([(vv == 1).sum() for v in taken.values() for vv in v.values()]) 5455 + n_frames 5456 ) 5457 factor = 1 5458 threshold_start = float( 5459 torch.mean( 5460 torch.tensor( 5461 [ 5462 torch.mean( 5463 torch.tensor([torch.mean(y[y > 0]) for y in x.values()]) 5464 ) 5465 for x in score_tensors.values() 5466 ] 5467 ) 5468 ) 5469 ) 5470 while ( 5471 sum([(vv == 1).sum() for v in taken.values() for vv in v.values()]) 5472 < n_frames 5473 ): 5474 threshold = threshold_start * factor 5475 intervals = [] 5476 interval_scores = [] 5477 key1 = list(score_tensors.keys())[0] 5478 key2 = list(score_tensors[key1].keys())[0] 5479 num_scores = score_tensors[key1][key2].shape[0] 5480 for i in range(num_scores): 5481 v_dict = dataset.find_valleys( 5482 predicted=score_tensors, 5483 threshold=threshold, 5484 min_frames=min_length, 5485 main_class=i, 5486 low=False, 5487 ) 5488 for v_id, interval_list in 
v_dict.items(): 5489 intervals += [x + [v_id] for x in interval_list] 5490 interval_scores += [ 5491 float(torch.mean(score_tensors[v_id][clip_id][i, start:end])) 5492 for start, end, clip_id in interval_list 5493 ] 5494 intervals = np.array(intervals)[np.argsort(interval_scores)[::-1]] 5495 i = 0 5496 while sum( 5497 [(vv == 1).sum() for v in taken.values() for vv in v.values()] 5498 ) < n_frames and i < len(intervals): 5499 start, end, clip_id, video_id = intervals[i] 5500 i += 1 5501 taken[video_id][clip_id][int(start) : int(end)] = 1 5502 factor *= 0.9 5503 if factor < 0.05: 5504 warnings.warn(f"Could not find enough frames!") 5505 break 5506 active_learning_intervals = {video_id: [] for video_id in video_intervals} 5507 for video_id in taken: 5508 for clip_id in taken[video_id]: 5509 if video_id in annotated and clip_id in annotated[video_id]: 5510 for start, end in annotated[video_id][clip_id]: 5511 taken[video_id][clip_id][int(start) : int(end)] = 0 5512 if (taken[video_id][clip_id] == 1).sum() == 0: 5513 continue 5514 indices = np.where(taken[video_id][clip_id].numpy())[0] 5515 boundaries = self._get_intervals(indices) 5516 active_learning_intervals[video_id] += [ 5517 [start, end, clip_id] for start, end in boundaries 5518 ] 5519 return active_learning_intervals 5520 5521 def _update_mask( 5522 self, 5523 task: TaskDispatcher, 5524 mask_name: str, 5525 score_tensors: Dict, 5526 n_frames: int, 5527 min_length: int, 5528 ) -> None: 5529 """Update the real_lens with intervals with the highest score of total length `n_frames`. 
5530 5531 Parameters 5532 ---------- 5533 task : TaskDispatcher 5534 the task dispatcher object 5535 mask_name : str 5536 the name of the real_lens 5537 score_tensors : dict 5538 a dictionary where keys are clip ids and values are framewise score tensors 5539 n_frames : int 5540 the number of frames to "annotate" 5541 min_length : int 5542 the minimum length of the annotated intervals 5543 5544 """ 5545 mask = self._load_mask(mask_name) 5546 video_intervals, _ = task.dataset("train").get_intervals() 5547 masked = { 5548 video_id: defaultdict(lambda: {}) for video_id in video_intervals.keys() 5549 } 5550 total_masked = 0 5551 total_all = 0 5552 for video_id in video_intervals: 5553 for clip_id in video_intervals[video_id]: 5554 masked[video_id][clip_id] = torch.zeros( 5555 task.dataset("train").get_len(video_id, clip_id) 5556 ) 5557 if ( 5558 video_id in mask["unannotated"] 5559 and clip_id in mask["unannotated"][video_id] 5560 ): 5561 for start, end in mask["unannotated"][video_id][clip_id]: 5562 score_tensors[video_id][clip_id][:, start:end] = -10 5563 masked[video_id][clip_id][int(start) : int(end)] = 1 5564 if ( 5565 video_id in mask["val_intervals"] 5566 and clip_id in mask["val_intervals"][video_id] 5567 ): 5568 for start, end in mask["val_intervals"][video_id][clip_id]: 5569 score_tensors[video_id][clip_id][:, start:end] = -10 5570 masked[video_id][clip_id][int(start) : int(end)] = 1 5571 total_all += torch.sum(masked[video_id][clip_id] == 0) 5572 if video_id in mask["masked"] and clip_id in mask["masked"][video_id]: 5573 # print(f'{real_lens["masked"][video_id][clip_id]=}') 5574 for start, end in mask["masked"][video_id][clip_id]: 5575 masked[video_id][clip_id][int(start) : int(end)] = 1 5576 total_masked += end - start 5577 old_n_frames = sum( 5578 [(vv == 0).sum() for v in masked.values() for vv in v.values()] 5579 ) 5580 n_frames = old_n_frames + n_frames 5581 factor = 1 5582 while ( 5583 sum([(vv == 0).sum() for v in masked.values() for vv in 
v.values()]) 5584 < n_frames 5585 ): 5586 threshold = float( 5587 torch.mean( 5588 torch.tensor( 5589 [ 5590 torch.mean( 5591 torch.tensor([torch.mean(y[y > 0]) for y in x.values()]) 5592 ) 5593 for x in score_tensors.values() 5594 ] 5595 ) 5596 ) 5597 ) 5598 threshold = threshold * factor 5599 intervals = [] 5600 interval_scores = [] 5601 key1 = list(score_tensors.keys())[0] 5602 key2 = list(score_tensors[key1].keys())[0] 5603 num_scores = score_tensors[key1][key2].shape[0] 5604 for i in range(num_scores): 5605 v_dict = task.dataset("train").find_valleys( 5606 predicted=score_tensors, 5607 threshold=threshold, 5608 min_frames=min_length, 5609 main_class=i, 5610 low=False, 5611 ) 5612 for v_id, interval_list in v_dict.items(): 5613 intervals += [x + [v_id] for x in interval_list] 5614 interval_scores += [ 5615 float(torch.mean(score_tensors[v_id][clip_id][i, start:end])) 5616 for start, end, clip_id in interval_list 5617 ] 5618 intervals = np.array(intervals)[np.argsort(interval_scores)[::-1]] 5619 i = 0 5620 while sum( 5621 [(vv == 0).sum() for v in masked.values() for vv in v.values()] 5622 ) < n_frames and i < len(intervals): 5623 start, end, clip_id, video_id = intervals[i] 5624 i += 1 5625 masked[video_id][clip_id][int(start) : int(end)] = 0 5626 factor *= 0.9 5627 if factor < 0.05: 5628 warnings.warn(f"Could not find enough frames!") 5629 break 5630 mask["masked"] = {video_id: {} for video_id in video_intervals} 5631 total_masked_new = 0 5632 for video_id in masked: 5633 for clip_id in masked[video_id]: 5634 if ( 5635 video_id in mask["unannotated"] 5636 and clip_id in mask["unannotated"][video_id] 5637 ): 5638 for start, end in mask["unannotated"][video_id][clip_id]: 5639 masked[video_id][clip_id][int(start) : int(end)] = 0 5640 if ( 5641 video_id in mask["val_intervals"] 5642 and clip_id in mask["val_intervals"][video_id] 5643 ): 5644 for start, end in mask["val_intervals"][video_id][clip_id]: 5645 masked[video_id][clip_id][int(start) : int(end)] = 0 5646 
indices = np.where(masked[video_id][clip_id].numpy())[0] 5647 mask["masked"][video_id][clip_id] = self._get_intervals(indices) 5648 for video_id in mask["masked"]: 5649 for clip_id in mask["masked"][video_id]: 5650 for start, end in mask["masked"][video_id][clip_id]: 5651 total_masked_new += end - start 5652 self._save_mask(mask, mask_name) 5653 with open( 5654 os.path.join( 5655 self.project_path, "results", f"{mask_name}.txt", encoding="utf-8" 5656 ), 5657 "a", 5658 ) as f: 5659 f.write(f"from {total_masked} to {total_masked_new} / {total_all}" + "\n") 5660 print(f"Unmasked from {total_masked} to {total_masked_new} / {total_all}") 5661 5662 def _visualize_results_label( 5663 self, 5664 episode_name: str, 5665 label: str, 5666 load_epoch: int = None, 5667 parameters_update: Dict = None, 5668 add_legend: bool = True, 5669 ground_truth: bool = True, 5670 hide_axes: bool = False, 5671 width: float = 10, 5672 whole_video: bool = False, 5673 transparent: bool = False, 5674 num_plots: int = 1, 5675 smooth_interval: int = 0, 5676 ): 5677 other_path = os.path.join(self.project_path, "results", "other") 5678 if not os.path.exists(other_path): 5679 os.mkdir(other_path) 5680 if parameters_update is None: 5681 parameters_update = {} 5682 if "model" in parameters_update.keys(): 5683 raise ValueError("Cannot change model parameters after training!") 5684 task, parameters, _ = self._make_task_prediction( 5685 "_", 5686 load_episode=episode_name, 5687 parameters_update=parameters_update, 5688 load_epoch=load_epoch, 5689 mode="val", 5690 ) 5691 for i in range(num_plots): 5692 print(i) 5693 task._visualize_results_label( 5694 smooth_interval=smooth_interval, 5695 label=label, 5696 save_path=os.path.join( 5697 other_path, f"{episode_name}_prediction_{i}.jpg" 5698 ), 5699 add_legend=add_legend, 5700 ground_truth=ground_truth, 5701 hide_axes=hide_axes, 5702 whole_video=whole_video, 5703 transparent=transparent, 5704 dataset="val", 5705 width=width, 5706 title=str(i), 5707 ) 5708 5709 
def plot_confusion_matrix(
    self,
    episode_name: str,
    load_epoch: int = None,
    parameters_update: Dict = None,
    metric: str = "recall",
    mode: str = "val",
    remove_saved_features: bool = False,
    save_path: str = None,
    cmap: str = "viridis",
) -> Tuple[ndarray, Iterable]:
    """Make a confusion matrix plot and return the data.

    If the annotation is non-exclusive, only false positive labels are considered.

    Parameters
    ----------
    episode_name : str
        the name of the episode to load
    load_epoch : int, optional
        the index of the epoch to load (by default the last one is loaded)
    parameters_update : dict, optional
        a dictionary of parameter updates (only for "data" and "general" categories)
    metric : {"recall", "precision"}
        for datasets with non-exclusive annotation, if `metric` is `"recall"`, only false positives are taken
        into account, and if `metric` is `"precision"`, only false negatives
    mode : {'val', 'all', 'test', 'train'}
        the subset of the data to make the prediction for (the value returned by `_make_task_prediction`
        is the one that is actually used)
    remove_saved_features : bool, default False
        if `True`, the dataset that is used for computation is then deleted
    save_path : str, optional
        the path to save the figure at (if `None`, the figure is shown instead)
    cmap : str, default "viridis"
        the name of the matplotlib colormap used for the matrix image

    Returns
    -------
    confusion_matrix : np.ndarray
        a confusion matrix of shape `(#classes, #classes)` where `A[i, j] = F_ij/N_i`, `F_ij` is the number of
        frames that have the i-th label in the ground truth and a false positive j-th label in the prediction,
        `N_i` is the number of frames that have the i-th label in the ground truth
    classes : list
        a list of labels

    """
    task, parameters, mode = self._make_task_prediction(
        "_",
        load_episode=episode_name,
        load_epoch=load_epoch,
        parameters_update=parameters_update,
        mode=mode,
    )
    dataset = task.dataset(mode)
    prediction = task.predict(dataset, raw_output=True)
    # The dataset hands back the metric it applied; keep it in `metric` so the
    # title below reflects it. (Unpacking into a local called `type` here made
    # the argument on the same line an unbound local.)
    confusion_matrix, classes, metric = dataset.get_confusion_matrix(
        prediction, metric
    )
    if remove_saved_features:
        self._remove_stores(parameters)
    n_classes = len(classes)
    fig, ax = plt.subplots(figsize=(n_classes, n_classes))
    ax.imshow(confusion_matrix, cmap=cmap)
    # Show all ticks and label them with the respective list entries
    ax.set_xticks(np.arange(n_classes))
    ax.set_xticklabels(classes)
    ax.set_yticks(np.arange(n_classes))
    ax.set_yticklabels(classes)
    # Rotate the tick labels and set their alignment.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
    # Loop over data dimensions and create text annotations.
    for i in range(n_classes):
        for j in range(n_classes):
            ax.text(
                j,
                i,
                np.round(confusion_matrix[i, j], 2),
                ha="center",
                va="center",
                color="w",
            )
    if metric is not None:
        ax.set_title(f"{metric} {episode_name}")
    else:
        ax.set_title(episode_name)
    fig.tight_layout()
    if save_path is None:
        plt.show()
    else:
        plt.savefig(save_path)
        plt.close()
    return confusion_matrix, classes


def _plot_ethograms_gt_pred(
    self,
    data_gt: dict,
    data_pred: dict,
    labels_gt: list,
    labels_pred: list,
    start: int = 0,
    end: int = -1,
    cmap_pred: str = "binary",
    cmap_gt: str = "binary",
    save: str = None,
    fontsize: int = 22,
    time_mode: str = "frames",
    fps: int = None,
) -> None:
    """Plot ground truth and predicted ethograms, one subplot per behavior.

    Each subplot has two rows: the ground truth ("GT") on top and the thresholded
    prediction ("Pred") below.

    Parameters
    ----------
    data_gt : dict
        the ground truth data; it is converted with `binarize_data(data_gt, max_frame=end)`
    data_pred : dict
        the predictions; only the first entry is used, it has to support `.numpy()` and is
        thresholded at 0.5
    labels_gt : list
        the behavior labels of the ground truth rows
    labels_pred : list
        the behavior labels of the prediction rows (indexed with `0 .. len(labels_pred) - 1`)
    start : int, default 0
        the first frame shown on the x axis
    end : int, default -1
        the last frame shown; if negative, the shorter of the two arrays defines it
    cmap_pred : str, default "binary"
        the colormap for the prediction row
    cmap_gt : str, default "binary"
        the colormap for the ground truth row
    save : str, optional
        the path to save the figure at (if `None`, the figure is shown instead)
    fontsize : int, default 22
        the font size for tick and axis labels
    time_mode : {"frames", "seconds"}
        the unit of the x axis
    fps : int, optional
        the frame rate (required when `time_mode` is `"seconds"`)

    """
    # Only the first prediction entry is plotted, binarized at 0.5
    first_key = list(data_pred.keys())[0]
    best_pred = data_pred[first_key].numpy() > 0.5
    # NOTE(review): `end` may still be -1 here; confirm `binarize_data` handles that.
    data_gt = binarize_data(data_gt, max_frame=end)

    # Crop data to min length
    if end < 0:
        end = min(data_gt.shape[1], best_pred.shape[1])
    data_gt = data_gt[:, :end]
    best_pred = best_pred[:, :end]

    # Reorder behaviors
    ind_gt = []
    ind_pred = []
    labels_pred = [labels_pred[i] for i in range(len(labels_pred))]
    # NOTE(review): the shift by one presumably realigns labels with prediction rows - confirm.
    labels_pred = np.roll(labels_pred, 1).tolist()
    check_gt = np.where(np.sum(data_gt, axis=1) > 0)[0]
    check_pred = np.where(np.sum(best_pred, axis=1) > 0)[0]
    for k, gt_beh in enumerate(labels_gt):
        if gt_beh in labels_pred:
            j = labels_pred.index(gt_beh)
            # Skip behaviors that never occur in either array
            if k not in check_gt and j not in check_pred:
                continue
            ind_gt.append(labels_gt.index(gt_beh))
            ind_pred.append(j)
    # Create label list
    labels = np.array(labels_gt)[ind_gt]
    assert (labels == np.array(labels_pred)[ind_pred]).all()

    # Create image
    image_pred = best_pred[ind_pred].astype(float)
    image_gt = data_gt[ind_gt]

    f, axs = plt.subplots(
        len(labels), 1, figsize=(5 * len(labels), 15), sharex=True
    )
    # With a single behavior `plt.subplots` returns a bare Axes, which cannot be iterated
    axs = np.atleast_1d(axs)
    end = image_gt.shape[1] if end < 0 else end
    for i, (ax, label) in enumerate(zip(axs, labels)):
        # Negative filler rows are masked out so that each image only paints its own row
        im1 = np.array([image_gt[i], np.ones_like(image_gt[i]) * (-1)])
        im1 = np.ma.masked_array(im1, im1 < 0)

        im2 = np.array([np.ones_like(image_pred[i]) * (-1), image_pred[i]])
        im2 = np.ma.masked_array(im2, im2 < 0)

        ax.imshow(im1, aspect="auto", cmap=cmap_gt, interpolation="nearest")
        ax.imshow(im2, aspect="auto", cmap=cmap_pred, interpolation="nearest")

        ax.set_yticks(np.arange(2), ["GT", "Pred"], fontsize=fontsize)
        ax.tick_params(axis="x", labelsize=fontsize)
        ax.set_ylabel(label, fontsize=fontsize)
        if time_mode == "frames":
            ax.set_xlabel("Num Frames", fontsize=fontsize)
        elif time_mode == "seconds":
            assert fps is not None, "Please provide fps"
            ax.set_xlabel("Time (s)", fontsize=fontsize)
            ax.set_xticks(
                np.linspace(0, end, 10),
                np.linspace(0, end / fps, 10).astype(np.int32),
            )

        ax.set_xlim(start, end)

    if save is None:
        plt.show()
    else:
        plt.savefig(save)
plt.close() 5881 5882 def plot_ethograms( 5883 self, 5884 episode_name: str, 5885 prediction_name: str, 5886 start: int = 0, 5887 end: int = -1, 5888 save_path: str = None, 5889 cmap_pred: str = "binary", 5890 cmap_gt: str = "binary", 5891 fontsize: int = 22, 5892 time_mode: str = "frames", 5893 fps: int = None, 5894 ): 5895 """Plot ethograms from start to end time (in frames) for ground truth and prediction""" 5896 params = self._read_parameters(catch_blanks=False) 5897 parameters = self._get_data_pars( 5898 params, 5899 ) 5900 if not save_path is None: 5901 os.makedirs(save_path, exist_ok=True) 5902 gt_files = [ 5903 f for f in self.data_path if f.endswith(parameters["annotation_suffix"]) 5904 ] 5905 pred_path = os.path.join( 5906 self.project_path, "results", "predictions", prediction_name 5907 ) 5908 pred_paths = [os.path.join(pred_path, f) for f in os.listdir(pred_path)] 5909 for pred_path in pred_paths: 5910 predictions = load_pickle(pred_path) 5911 behaviors = self.get_behavior_dictionary(episode_name) 5912 gt_filename = os.path.basename(pred_path).replace( 5913 "_".join(["_" + prediction_name, "prediction.pickle"]), 5914 parameters["annotation_suffix"], 5915 ) 5916 if os.path.exists(os.path.join(self.data_path, gt_filename)): 5917 gt_data = load_pickle(os.path.join(self.data_path, gt_filename)) 5918 5919 self._plot_ethograms_gt_pred( 5920 gt_data, 5921 predictions, 5922 gt_data[1], 5923 behaviors, 5924 start=start, 5925 end=end, 5926 save=os.path.join( 5927 save_path, 5928 os.path.splitext(os.path.basename(pred_path))[0] + "_gt_pred", 5929 ), 5930 cmap_pred=cmap_pred, 5931 cmap_gt=cmap_gt, 5932 fontsize=fontsize, 5933 time_mode=time_mode, 5934 fps=fps, 5935 ) 5936 else: 5937 print("GT file not found") 5938 5939 def _create_side_panel(self, height, width, labels_pred, preds, labels_gt, gt=None): 5940 """Create a side panel for video annotation display. 
5941 5942 Parameters 5943 ---------- 5944 height : int 5945 the height of the panel 5946 width : int 5947 the width of the panel 5948 labels_pred : list 5949 the list of predicted behavior labels 5950 preds : array-like 5951 the prediction values for each behavior 5952 labels_gt : list 5953 the list of ground truth behavior labels 5954 gt : array-like, optional 5955 the ground truth values for each behavior 5956 5957 Returns 5958 ------- 5959 side_panel : np.ndarray 5960 the created side panel as an image array 5961 5962 """ 5963 side_panel = np.ones((height, int(width / 4), 3), dtype=np.uint8) * 255 5964 5965 beh_indices = np.where(preds)[0] 5966 for i, label in enumerate(labels_pred): 5967 color = (0, 0, 0) 5968 if i in beh_indices: 5969 color = (0, 255, 0) 5970 cv2.putText( 5971 side_panel, 5972 label, 5973 (10, 50 + 50 * i), 5974 cv2.FONT_HERSHEY_SIMPLEX, 5975 1, 5976 color, 5977 2, 5978 cv2.LINE_AA, 5979 ) 5980 if gt is not None: 5981 beh_indices_gt = np.where(gt)[0] 5982 for i, label in enumerate(labels_gt): 5983 color = (0, 0, 0) 5984 if i in beh_indices_gt: 5985 color = (0, 255, 0) 5986 cv2.putText( 5987 side_panel, 5988 label, 5989 (10, 50 + 50 * i + 80 * len(labels_pred)), 5990 cv2.FONT_HERSHEY_SIMPLEX, 5991 1, 5992 color, 5993 2, 5994 cv2.LINE_AA, 5995 ) 5996 return side_panel 5997 5998 def create_annotated_video( 5999 self, 6000 prediction_file_paths: list, 6001 video_file_paths: list, 6002 episode_name: str, # To get the list of behaviors 6003 ground_truth_file_paths: list = None, 6004 pred_thresh: float = 0.5, 6005 start: int = 0, 6006 end: int = -1, 6007 ): 6008 """Create a video with the predictions overlaid on the video""" 6009 for k, (pred_path, vid_path) in enumerate( 6010 zip(prediction_file_paths, video_file_paths) 6011 ): 6012 print("Generating video for :", os.path.basename(vid_path)) 6013 predictions = load_pickle(pred_path) 6014 best_pred = predictions[list(predictions.keys())[0]].numpy() > pred_thresh 6015 behaviors = 
self.get_behavior_dictionary(episode_name) 6016 # Load video 6017 labels_pred = [behaviors[i] for i in range(len(behaviors))] 6018 labels_pred = np.roll( 6019 labels_pred, 1 6020 ).tolist() 6021 6022 gt_data = None 6023 if ground_truth_file_paths is not None: 6024 gt_data = load_pickle(ground_truth_file_paths[k]) 6025 labels_gt = gt_data[1] 6026 gt_data = binarize_data(gt_data, max_frame=best_pred.shape[1]) 6027 6028 cap = cv2.VideoCapture(vid_path) 6029 cap.set(cv2.CAP_PROP_POS_FRAMES, start) 6030 end = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) if end < 0 else end 6031 fps = cap.get(cv2.CAP_PROP_FPS) 6032 width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) 6033 height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) 6034 fourcc = cv2.VideoWriter_fourcc(*"mp4v") 6035 out = cv2.VideoWriter( 6036 os.path.join( 6037 os.path.dirname(vid_path), 6038 os.path.splitext(os.path.basename(vid_path))[0] + "_annotated.mp4", 6039 ), 6040 fourcc, 6041 fps, 6042 # (width + int(width/4) , height), 6043 (600, 300), 6044 ) 6045 count = 0 6046 bar = tqdm(total=end - start) 6047 while cap.isOpened(): 6048 ret, frame = cap.read() 6049 if not ret: 6050 break 6051 6052 side_panel = self._create_side_panel( 6053 height, 6054 width, 6055 labels_pred, 6056 best_pred[:, count], 6057 labels_gt, 6058 gt_data[:, count], 6059 ) 6060 frame = np.concatenate((frame, side_panel), axis=1) 6061 frame = cv2.resize(frame, (0, 0), fx=0.25, fy=0.25) 6062 out.write(frame) 6063 count += 1 6064 bar.update(1) 6065 6066 if count > end: 6067 break 6068 6069 cap.release() 6070 out.release() 6071 cv2.destroyAllWindows() 6072 6073 def plot_predictions( 6074 self, 6075 episode_name: str, 6076 load_epoch: int = None, 6077 parameters_update: Dict = None, 6078 add_legend: bool = True, 6079 ground_truth: bool = True, 6080 colormap: str = "dlc2action", 6081 hide_axes: bool = False, 6082 min_classes: int = 1, 6083 width: float = 10, 6084 whole_video: bool = False, 6085 transparent: bool = False, 6086 drop_classes: Set = None, 6087 
search_classes: Set = None, 6088 num_plots: int = 1, 6089 remove_saved_features: bool = False, 6090 smooth_interval_prediction: int = 0, 6091 data_path: str = None, 6092 file_paths: Set = None, 6093 mode: str = "val", 6094 font_size: float = None, 6095 window_size: int = 400, 6096 ) -> None: 6097 """Visualize random predictions. 6098 6099 Parameters 6100 ---------- 6101 episode_name : str 6102 the name of the episode to load 6103 load_epoch : int, optional 6104 the epoch to load (by default last) 6105 parameters_update : dict, optional 6106 parameter update dictionary 6107 add_legend : bool, default True 6108 if True, legend will be added to the plot 6109 ground_truth : bool, default True 6110 if True, ground truth will be added to the plot 6111 colormap : str, default 'Accent' 6112 the `matplotlib` colormap to use 6113 hide_axes : bool, default True 6114 if `True`, the axes will be hidden on the plot 6115 min_classes : int, default 1 6116 the minimum number of classes in a displayed interval 6117 width : float, default 10 6118 the width of the plot 6119 whole_video : bool, default False 6120 if `True`, whole videos are plotted instead of segments 6121 transparent : bool, default False 6122 if `True`, the background on the plot is transparent 6123 drop_classes : set, optional 6124 a set of class names to not be displayed 6125 search_classes : set, optional 6126 if given, only intervals where at least one of the classes is in ground truth will be shown 6127 num_plots : int, default 1 6128 the number of plots to make 6129 remove_saved_features : bool, default False 6130 if `True`, the dataset will be deleted after computation 6131 smooth_interval_prediction : int, default 0 6132 if >0, predictions shorter than this number of frames are removed (filled with prediction for the previous frame) 6133 data_path : str, optional 6134 the data path to run the prediction for 6135 file_paths : set, optional 6136 a set of string file paths (data with all prefixes + feature 
files, in any order) to run the prediction 6137 for 6138 mode : {'all', 'test', 'val', 'train'} 6139 the subset of the data to make the prediction for (forced to 'all' if data_path is not None) 6140 6141 """ 6142 plot_path = os.path.join(self.project_path, "results", "plots") 6143 task, parameters, mode = self._make_task_prediction( 6144 "_", 6145 load_episode=episode_name, 6146 parameters_update=parameters_update, 6147 load_epoch=load_epoch, 6148 data_path=data_path, 6149 file_paths=file_paths, 6150 mode=mode, 6151 ) 6152 os.makedirs(plot_path, exist_ok=True) 6153 task.visualize_results( 6154 save_path=os.path.join(plot_path, f"{episode_name}_prediction.svg"), 6155 add_legend=add_legend, 6156 ground_truth=ground_truth, 6157 colormap=colormap, 6158 hide_axes=hide_axes, 6159 min_classes=min_classes, 6160 whole_video=whole_video, 6161 transparent=transparent, 6162 dataset=mode, 6163 drop_classes=drop_classes, 6164 search_classes=search_classes, 6165 width=width, 6166 smooth_interval_prediction=smooth_interval_prediction, 6167 font_size=font_size, 6168 num_plots=num_plots, 6169 window_size=window_size, 6170 ) 6171 if remove_saved_features: 6172 self._remove_stores(parameters) 6173 6174 def create_video_from_labels( 6175 self, 6176 video_dir_path: str, 6177 mode="ground_truth", 6178 prediction_name: str = None, 6179 save_path: str = None, 6180 ): 6181 if save_path is None: 6182 save_path = os.path.join( 6183 self.project_path, "results", f"annotated_videos_from_{mode}" 6184 ) 6185 os.makedirs(save_path, exist_ok=True) 6186 6187 params = self._read_parameters(catch_blanks=False) 6188 6189 if mode == "ground_truth": 6190 source_dir = self.annotation_path 6191 annotation_suffix = params["data"]["annotation_suffix"] 6192 elif mode == "prediction": 6193 assert ( 6194 not prediction_name is None 6195 ), "Please provide a prediction name with mode 'prediction'" 6196 source_dir = os.path.join( 6197 self.project_path, "results", "predictions", prediction_name 6198 ) 6199 
annotation_suffix = f"_{prediction_name}_prediction.pickle" 6200 6201 video_annotation_pairs = [ 6202 ( 6203 os.path.join(video_dir_path, f), 6204 os.path.join( 6205 source_dir, f.replace(f.split(".")[-1], annotation_suffix) 6206 ), 6207 ) 6208 for f in os.listdir(video_dir_path) 6209 if os.path.exists( 6210 os.path.join(source_dir, f.replace(f.split(".")[-1], annotation_suffix)) 6211 ) 6212 ] 6213 6214 for video_file, annotation_file in tqdm(video_annotation_pairs): 6215 if not os.path.exists(video_file): 6216 print(f"Video file {video_file} does not exist, skipping.") 6217 continue 6218 if not os.path.exists(annotation_file): 6219 print(f"Annotation file {annotation_file} does not exist, skipping.") 6220 continue 6221 6222 if annotation_file.endswith(".pickle"): 6223 annotations = load_pickle(annotation_file) 6224 elif annotation_file.endswith(".csv"): 6225 annotations = pd.read_csv(annotation_file) 6226 6227 if mode == "ground_truth": 6228 behaviors = annotations[1] 6229 annot_data = annotations[3] 6230 elif mode == "predictions": 6231 behaviors = list(annotations["classes"].values()) 6232 annot_data = [ 6233 annotations[key] 6234 for key in annotations.keys() 6235 if key not in ["classes", "min_frame", "max_frame"] 6236 ] 6237 if params["general"]["exclusive"]: 6238 annot_data = [np.argmax(annot, axis=1) for annot in annot_data] 6239 seqs = [ 6240 [ 6241 self._bin_array_to_sequences(annot, target_value=k) 6242 for k in range(len(behaviors)) 6243 ] 6244 for annot in annot_data 6245 ] 6246 else: 6247 annot_data = [np.where(annot > 0.5)[0] for annot in annot_data] 6248 seqs = [ 6249 self._bin_array_to_sequences(annot, target_value=1) 6250 for annot in annot_data 6251 ] 6252 annotations = ["", "", seqs] 6253 6254 for individual in annotations[3]: 6255 for behavior in annotations[3][individual]: 6256 intervals = annotations[3][individual][behavior] 6257 self._extract_videos( 6258 video_file, 6259 intervals, 6260 behavior, 6261 individual, 6262 save_path, 6263 
resolution=(640, 480), 6264 fps=30, 6265 ) 6266 6267 def _bin_array_to_sequences( 6268 self, annot_data: List[np.ndarray], target_value: int 6269 ) -> List[List[Tuple[int, int]]]: 6270 is_target = annot_data == target_value 6271 changes = np.diff(np.concatenate(([False], is_target, [False]))) 6272 indices = np.where(changes)[0].reshape(-1, 2) 6273 subsequences = [list(range(start, end)) for start, end in indices] 6274 return subsequences 6275 6276 def _extract_videos( 6277 self, 6278 video_file: str, 6279 intervals: np.ndarray, 6280 behavior: str, 6281 individual: str, 6282 video_dir: str, 6283 resolution: Tuple[int, int] = (640, 480), 6284 fps: int = 30, 6285 ) -> None: 6286 """Extract frames from a video file from frames in between intervals in behavior folder for a given individual""" 6287 cap = cv2.VideoCapture(video_file) 6288 print("Extracting frames from", video_file) 6289 6290 for start, end, confusion in tqdm(intervals): 6291 6292 frame_count = start 6293 assert start < end, "Start frame should be less than end frame" 6294 if confusion > 0.5: 6295 continue 6296 cap.set(cv2.CAP_PROP_POS_FRAMES, start) 6297 output_file = os.path.join( 6298 video_dir, 6299 individual, 6300 behavior, 6301 os.path.splitext(os.path.basename(video_file))[0] 6302 + f"vid_{individual}_{behavior}_{start:05d}_{end:05d}.mp4", 6303 ) 6304 fourcc = cv2.VideoWriter_fourcc(*"mp4v") # Codec, e.g., 'XVID', 'MJPG' 6305 out = cv2.VideoWriter( 6306 output_file, fourcc, fps, (resolution[0], resolution[1]) 6307 ) 6308 while cap.isOpened(): 6309 ret, frame = cap.read() 6310 if not ret: 6311 break 6312 6313 # Resize large frames 6314 frame = cv2.resize(frame, (640, 480)) 6315 out.write(frame) 6316 6317 frame_count += 1 6318 # Break if end frame is reached or max frames per behavior is reached 6319 if frame_count == end: 6320 break 6321 if frame_count <= 2: 6322 os.remove(output_file) 6323 # cap.release() 6324 out.release() 6325 6326 def create_metadata_backup(self) -> None: 6327 """Create a copy 
of the meta files.""" 6328 meta_copy_path = os.path.join(self.project_path, "meta", "backup") 6329 meta_path = os.path.join(self.project_path, "meta") 6330 if os.path.exists(meta_copy_path): 6331 shutil.rmtree(meta_copy_path) 6332 os.mkdir(meta_copy_path) 6333 for file in os.listdir(meta_path): 6334 if file == "backup": 6335 continue 6336 if os.path.isdir(os.path.join(meta_path, file)): 6337 continue 6338 shutil.copy( 6339 os.path.join(meta_path, file), os.path.join(meta_copy_path, file) 6340 ) 6341 6342 def load_metadata_backup(self) -> None: 6343 """Load from previously created meta data backup (in case of corruption).""" 6344 meta_copy_path = os.path.join(self.project_path, "meta", "backup") 6345 meta_path = os.path.join(self.project_path, "meta") 6346 for file in os.listdir(meta_copy_path): 6347 shutil.copy( 6348 os.path.join(meta_copy_path, file), os.path.join(meta_path, file) 6349 ) 6350 6351 def get_behavior_dictionary(self, episode_name: str) -> Dict: 6352 """Get the behavior dictionary for an episode. 6353 6354 Parameters 6355 ---------- 6356 episode_name : str 6357 the name of the episode 6358 6359 Returns 6360 ------- 6361 behaviors_dictionary : dict 6362 a dictionary where keys are label indices and values are label names 6363 6364 """ 6365 return self._episode(episode_name).get_behaviors_dict() 6366 6367 def import_episodes( 6368 self, 6369 episodes_directory: str, 6370 name_map: Dict = None, 6371 repeat_policy: str = "error", 6372 ) -> None: 6373 """Import episodes exported with `Project.export_episodes`. 
6374 6375 Parameters 6376 ---------- 6377 episodes_directory : str 6378 the path to the exported episodes directory 6379 name_map : dict, optional 6380 a name change dictionary for the episodes: keys are old names, values are new names 6381 repeat_policy : {'error', 'skip', 'force'}, default 'error' 6382 the policy for repeated episode names: 'error' raises an error, 'skip' skips duplicates, 6383 'force' overwrites existing episodes 6384 6385 """ 6386 if name_map is None: 6387 name_map = {} 6388 episodes = pd.read_pickle(os.path.join(episodes_directory, "episodes.pickle")) 6389 to_remove = [] 6390 import_string = "Imported episodes: " 6391 for episode_name in episodes.index: 6392 if episode_name in name_map: 6393 import_string += f"{episode_name} " 6394 episode_name = name_map[episode_name] 6395 import_string += f"({episode_name}), " 6396 else: 6397 import_string += f"{episode_name}, " 6398 try: 6399 self._check_episode_validity(episode_name, allow_doublecolon=True) 6400 except ValueError as e: 6401 if str(e).endswith("is already taken!"): 6402 if repeat_policy == "skip": 6403 to_remove.append(episode_name) 6404 elif repeat_policy == "force": 6405 self.remove_episode(episode_name) 6406 elif repeat_policy == "error": 6407 raise ValueError( 6408 f"The {episode_name} episode name is already taken; please use the name_map parameter to rename it" 6409 ) 6410 else: 6411 raise ValueError( 6412 f"The {repeat_policy} repeat policy is not recognized; please choose from ['skip', 'force' and 'error']" 6413 ) 6414 episodes = episodes.drop(index=to_remove) 6415 self._episodes().update( 6416 episodes, 6417 name_map=name_map, 6418 force=(repeat_policy == "force"), 6419 data_path=self.data_path, 6420 annotation_path=self.annotation_path, 6421 ) 6422 for episode_name in episodes.index: 6423 if episode_name in name_map: 6424 new_episode_name = name_map[episode_name] 6425 else: 6426 new_episode_name = episode_name 6427 model_dir = os.path.join( 6428 self.project_path, "results", 
"model", new_episode_name 6429 ) 6430 old_model_dir = os.path.join(episodes_directory, "model", episode_name) 6431 if os.path.exists(model_dir): 6432 shutil.rmtree(model_dir) 6433 os.mkdir(model_dir) 6434 for file in os.listdir(old_model_dir): 6435 shutil.copyfile( 6436 os.path.join(old_model_dir, file), os.path.join(model_dir, file) 6437 ) 6438 log_file = os.path.join( 6439 self.project_path, "results", "logs", f"{new_episode_name}.txt" 6440 ) 6441 old_log_file = os.path.join( 6442 episodes_directory, "logs", f"{episode_name}.txt" 6443 ) 6444 shutil.copyfile(old_log_file, log_file) 6445 print(import_string) 6446 print("\n") 6447 6448 def export_episodes( 6449 self, episode_names: List, output_directory: str, name: str = None 6450 ) -> None: 6451 """Save selected episodes as a file that can be imported into another project with `Project.import_episodes`. 6452 6453 Parameters 6454 ---------- 6455 episode_names : list 6456 a list of string episode names 6457 output_directory : str 6458 the path to the directory where the episodes will be saved 6459 name : str, optional 6460 the name of the episodes directory (by default `exported_episodes`) 6461 6462 """ 6463 if name is None: 6464 name = "exported_episodes" 6465 if os.path.exists( 6466 os.path.join(output_directory, name + ".zip") 6467 ) or os.path.exists(os.path.join(output_directory, name)): 6468 i = 1 6469 while os.path.exists( 6470 os.path.join(output_directory, name + f"_{i}.zip") 6471 ) or os.path.exists(os.path.join(output_directory, name + f"_{i}")): 6472 i += 1 6473 name = name + f"_{i}" 6474 dest_dir = os.path.join(output_directory, name) 6475 os.mkdir(dest_dir) 6476 os.mkdir(os.path.join(dest_dir, "model")) 6477 os.mkdir(os.path.join(dest_dir, "logs")) 6478 runs = [] 6479 for episode in episode_names: 6480 runs += self._episodes().get_runs(episode) 6481 for run in runs: 6482 shutil.copytree( 6483 os.path.join(self.project_path, "results", "model", run), 6484 os.path.join(dest_dir, "model", run), 6485 ) 
6486 shutil.copyfile( 6487 os.path.join(self.project_path, "results", "logs", f"{run}.txt"), 6488 os.path.join(dest_dir, "logs", f"{run}.txt"), 6489 ) 6490 data = self._episodes().get_subset(runs) 6491 data.to_pickle(os.path.join(dest_dir, "episodes.pickle")) 6492 6493 def get_results_table( 6494 self, 6495 episode_names: List, 6496 metrics: List = None, 6497 mode: str = "mean", # Choose between ["mean", "statistics", "detail"] 6498 print_results: bool = True, 6499 classes: List = None, 6500 ): 6501 """Generate a `pandas` dataframe with a summary of episode results. 6502 6503 Parameters 6504 ---------- 6505 episode_names : list 6506 a list of names of episodes to include 6507 metrics : list, optional 6508 a list of metric names to include 6509 mode : bool, optional 6510 the mode of the results table, choose between ["mean", "statistics", "detail"], by default "mean" 6511 print_results : bool, optional 6512 if True, the results will be printed to the console, by default True 6513 classes : list, optional 6514 a list of names of classes to include (by default all are included) 6515 6516 Returns 6517 ------- 6518 results : pd.DataFrame 6519 a table with the results 6520 6521 """ 6522 run_names = [] 6523 for episode in episode_names: 6524 run_names += self._episodes().get_runs(episode) 6525 episodes = self.list_episodes(run_names, print_results=False) 6526 metric_columns = [x for x in episodes.columns if x[0] == "results"] 6527 results_df = pd.DataFrame() 6528 if metrics is not None: 6529 metric_columns = [ 6530 x for x in metric_columns if x[1].split("_")[0] in metrics 6531 ] 6532 for episode in episode_names: 6533 results = [] 6534 metric_set = set() 6535 for run in self._episodes().get_runs(episode): 6536 beh_dict = self.get_behavior_dictionary(run) 6537 res_dict = defaultdict(lambda: {}) 6538 for column in metric_columns: 6539 if np.isnan(episodes.loc[run, column]): 6540 continue 6541 split = column[1].split("_") 6542 if split[-1].isnumeric(): 6543 beh_ind = 
int(split[-1]) 6544 metric_name = "_".join(split[:-1]) 6545 beh = beh_dict[beh_ind] 6546 else: 6547 beh = "average" 6548 metric_name = column[1] 6549 res_dict[beh][metric_name] = episodes.loc[run, column] 6550 metric_set.add(metric_name) 6551 if "average" not in res_dict: 6552 res_dict["average"] = {} 6553 for metric in metric_set: 6554 if metric not in res_dict["average"]: 6555 arr = [ 6556 res_dict[beh][metric] 6557 for beh in res_dict 6558 if metric in res_dict[beh] 6559 ] 6560 res_dict["average"][metric] = np.mean(arr) 6561 results.append(res_dict) 6562 episode_results = {} 6563 for metric in metric_set: 6564 for beh in results[0].keys(): 6565 if classes is not None and beh not in classes: 6566 continue 6567 arr = [] 6568 for res_dict in results: 6569 if metric in res_dict[beh]: 6570 arr.append(res_dict[beh][metric]) 6571 if len(arr) > 0: 6572 if mode == "statistics": 6573 episode_results[(beh, f"{episode} {metric} mean")] = ( 6574 np.mean(arr) 6575 ) 6576 episode_results[(beh, f"{episode} {metric} std")] = np.std( 6577 arr 6578 ) 6579 elif mode == "mean": 6580 episode_results[(beh, f"{episode} {metric}")] = np.mean(arr) 6581 elif mode == "detail": 6582 for i, val in enumerate(arr): 6583 episode_results[(beh, f"{episode}::{i} {metric}")] = val 6584 for key, value in episode_results.items(): 6585 results_df.loc[key[0], key[1]] = value 6586 if print_results: 6587 print(f"RESULTS:") 6588 print(results_df) 6589 print("\n") 6590 return results_df 6591 6592 def episode_exists(self, episode_name: str) -> bool: 6593 """Check if an episode already exists. 6594 6595 Parameters 6596 ---------- 6597 episode_name : str 6598 the episode name 6599 6600 Returns 6601 ------- 6602 exists : bool 6603 `True` if the episode exists 6604 6605 """ 6606 return self._episodes().check_name_validity(episode_name) 6607 6608 def search_exists(self, search_name: str) -> bool: 6609 """Check if a search already exists. 
6610 6611 Parameters 6612 ---------- 6613 search_name : str 6614 the search name 6615 6616 Returns 6617 ------- 6618 exists : bool 6619 `True` if the search exists 6620 6621 """ 6622 return self._searches().check_name_validity(search_name) 6623 6624 def prediction_exists(self, prediction_name: str) -> bool: 6625 """Check if a prediction already exists. 6626 6627 Parameters 6628 ---------- 6629 prediction_name : str 6630 the prediction name 6631 6632 Returns 6633 ------- 6634 exists : bool 6635 `True` if the prediction exists 6636 6637 """ 6638 return self._predictions().check_name_validity(prediction_name) 6639 6640 @staticmethod 6641 def project_name_available(projects_path: str, project_name: str): 6642 """Check if a project name is available. 6643 6644 Parameters 6645 ---------- 6646 projects_path : str 6647 the path to the projects directory 6648 project_name : str 6649 the name of the project to check 6650 6651 Returns 6652 ------- 6653 available : bool 6654 `True` if the project name is available 6655 6656 """ 6657 if projects_path is None: 6658 projects_path = os.path.join(str(Path.home()), "DLC2Action") 6659 return not os.path.exists(os.path.join(projects_path, project_name)) 6660 6661 def _update_episode_metrics(self, episode_name: str, metrics: Dict): 6662 """Update meta data with evaluation results. 6663 6664 Parameters 6665 ---------- 6666 episode_name : str 6667 the name of the episode 6668 metrics : dict 6669 the metrics dictionary to update with 6670 6671 """ 6672 self._episodes().update_episode_metrics(episode_name, metrics) 6673 6674 def rename_episode(self, episode_name: str, new_episode_name: str): 6675 """Rename an episode. 
6676 6677 Parameters 6678 ---------- 6679 episode_name : str 6680 the current episode name 6681 new_episode_name : str 6682 the new episode name 6683 6684 """ 6685 shutil.move( 6686 os.path.join(self.project_path, "results", "model", episode_name), 6687 os.path.join(self.project_path, "results", "model", new_episode_name), 6688 ) 6689 shutil.move( 6690 os.path.join(self.project_path, "results", "logs", f"{episode_name}.txt"), 6691 os.path.join( 6692 self.project_path, "results", "logs", f"{new_episode_name}.txt" 6693 ), 6694 ) 6695 self._episodes().rename_episode(episode_name, new_episode_name) 6696 6697 6698class _Runner: 6699 """A helper class for running hyperparameter searches.""" 6700 6701 def __init__( 6702 self, 6703 search_name: str, 6704 search_space: Dict, 6705 load_episode: str, 6706 load_epoch: int, 6707 metric: str, 6708 average: int, 6709 task: Union[TaskDispatcher, None], 6710 remove_saved_features: bool, 6711 project: Project, 6712 ): 6713 """Initialize the class. 6714 6715 Parameters 6716 ---------- 6717 task : TaskDispatcher 6718 the task dispatcher object 6719 search_name : str 6720 the name the search should be saved under 6721 search_space : dict 6722 a dictionary representing the search space; of this general structure: 6723 {'group/param_name': ('float/int/float_log/int_log', start, end), 6724 'group/param_name': ('categorical', [choices])}, e.g. 
6725 {'data/overlap': ('int', 5, 100), 'training/lr': ('float_log', 1e-4, 1e-2), 6726 'data/feature_extraction': ('categorical', ['kinematic', 'bones'])} 6727 load_episode : str 6728 the name of the episode to load the model from 6729 load_epoch : int 6730 the epoch to load the model from (if not provided, the last checkpoint is used) 6731 metric : str 6732 the metric to maximize/minimize (see direction) 6733 average : int 6734 the number of epochs to average the metric; if 0, the last value is taken 6735 remove_saved_features : bool 6736 if `True`, the old datasets will be deleted when data parameters change 6737 project : Project 6738 the parent `Project` instance 6739 6740 """ 6741 self.search_space = search_space 6742 self.load_episode = load_episode 6743 self.load_epoch = load_epoch 6744 self.metric = metric 6745 self.average = average 6746 self.feature_save_path = None 6747 self.remove_saved_featuress = remove_saved_features 6748 self.save_stores = project._save_stores 6749 self.remove_datasets = project.remove_saved_features 6750 self.task = task 6751 self.search_name = search_name 6752 self.update = project._update 6753 self.remove_episode = project.remove_episode 6754 self.fill = project._fill 6755 6756 def clean(self): 6757 """Remove datasets if needed. 6758 6759 This method removes saved feature datasets when the remove_saved_features flag is set. 6760 6761 """ 6762 if self.remove_saved_featuress: 6763 self.remove_datasets([os.path.basename(self.feature_save_path)]) 6764 6765 def run(self, trial, parameters): 6766 """Make a trial run. 
6767 6768 Parameters 6769 ---------- 6770 trial : optuna.trial.Trial 6771 the Optuna trial object 6772 parameters : dict 6773 the base parameters dictionary 6774 6775 Returns 6776 ------- 6777 value : float 6778 the metric value for this trial 6779 6780 """ 6781 params = deepcopy(parameters) 6782 param_update = defaultdict( 6783 lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: {}))) 6784 ) 6785 for full_name, space in self.search_space.items(): 6786 group, param_name = ( 6787 full_name.split("/")[0], 6788 "/".join(full_name.split("/")[1:]), 6789 ) 6790 log = space[0][-3:] == "log" 6791 if space[0].startswith("int"): 6792 value = trial.suggest_int(full_name, space[1], space[2], log=log) 6793 elif space[0].startswith("float"): 6794 value = trial.suggest_float(full_name, space[1], space[2], log=log) 6795 elif space[0] == "categorical": 6796 value = trial.suggest_categorical(full_name, space[1]) 6797 else: 6798 raise ValueError( 6799 "The search space has to be formatted as either " 6800 '("float"/"int"/"float_log"/"int_log", start, end) ' 6801 f'or ("categorical", [choices]); got {space} for {group}/{param_name}' 6802 ) 6803 if len(param_name.split("/")) == 1: 6804 param_update[group][param_name] = value 6805 else: 6806 pars = param_name.split("/") 6807 pars = [int(x) if x.isnumeric() else x for x in pars] 6808 if len(pars) == 2: 6809 param_update[group][pars[0]][pars[1]] = value 6810 elif len(pars) == 3: 6811 param_update[group][pars[0]][pars[1]][pars[2]] = value 6812 elif len(pars) == 4: 6813 param_update[group][pars[0]][pars[1]][pars[2]][pars[3]] = value 6814 param_update = {k: dict(v) for k, v in param_update.items()} 6815 params = self.update(params, param_update) 6816 self.remove_episode(f"_{self.search_name}") 6817 params = self.fill( 6818 params, 6819 f"_{self.search_name}", 6820 self.load_episode, 6821 load_epoch=self.load_epoch, 6822 only_load_model=True, 6823 ) 6824 if self.feature_save_path != params["data"]["feature_save_path"]: 6825 if 
self.feature_save_path is not None: 6826 self.clean() 6827 self.feature_save_path = params["data"]["feature_save_path"] 6828 self.save_stores(params) 6829 if self.task is None: 6830 self.task = TaskDispatcher(deepcopy(params)) 6831 else: 6832 self.task.update_task(params) 6833 6834 _, metrics_log = self.task.train(trial, self.metric) 6835 if self.metric in metrics_log["val"].keys(): 6836 metric_values = metrics_log["val"][self.metric] 6837 if self.average > 0: 6838 value = np.mean(sorted(metric_values)[-self.average :]) 6839 else: 6840 value = metric_values[-1] 6841 return value 6842 else: # ['accuracy', 'precision', 'f1', 'recall', 'count', 'segmental_precision', 'segmental_recall', 'segmental_f1', 'edit_distance', 'f_beta', 'segmental_f_beta', 'semisegmental_precision', 'semisegmental_recall', 'semisegmental_f1', 'pr-auc', 'semisegmental_pr-auc', 'mAP'] 6843 if self.metric in [ 6844 "f1", 6845 "precision", 6846 "recall", 6847 "accuracy", 6848 "count", 6849 "segmental_precision", 6850 "segmental_recall", 6851 "segmental_f1", 6852 "f_beta", 6853 "segmental_f_beta", 6854 "semisegmental_precision", 6855 "semisegmental_recall", 6856 "semisegmental_f1", 6857 "pr-auc", 6858 "semisegmental_pr-auc", 6859 "mAP", 6860 ]: 6861 return 0 6862 elif self.metric in ["loss", "mse", "mae", "edit_distance"]: 6863 return float("inf")
55class Project: 56 """A class to create and maintain the project files + keep track of experiments.""" 57 58 def __init__( 59 self, 60 name: str, 61 data_type: str = None, 62 annotation_type: str = "none", 63 projects_path: str = None, 64 data_path: Union[str, List] = None, 65 annotation_path: Union[str, List] = None, 66 copy: bool = False, 67 ) -> None: 68 """Initialize the class. 69 70 Parameters 71 ---------- 72 name : str 73 name of the project 74 data_type : str, optional 75 data type (run Project.data_types() to see available options; has to be provided if the project is being 76 created) 77 annotation_type : str, default 'none' 78 annotation type (run Project.annotation_types() to see available options) 79 projects_path : str, optional 80 path to the projects folder (is filled with ~/DLC2Action by default) 81 data_path : str, optional 82 path to the folder containing input files for the project (has to be provided if the project is being 83 created) 84 annotation_path : str, optional 85 path to the folder containing annotation files for the project 86 copy : bool, default False 87 if True, the files from annotation_path and data_path will be copied to the projects folder; 88 otherwise they will be moved 89 90 """ 91 if projects_path is None: 92 projects_path = os.path.join(str(Path.home()), "DLC2Action") 93 if not os.path.exists(projects_path): 94 os.mkdir(projects_path) 95 self.project_path = os.path.join(projects_path, name) 96 self.name = name 97 self.data_type = data_type 98 self.annotation_type = annotation_type 99 self.data_path = data_path 100 self.annotation_path = annotation_path 101 if not os.path.exists(self.project_path): 102 if data_type is None: 103 raise ValueError( 104 "The data_type parameter is necessary when creating a new project!" 
105 ) 106 self._initialize_project( 107 data_type, annotation_type, data_path, annotation_path, copy 108 ) 109 else: 110 self.annotation_type, self.data_type = self._read_types() 111 if data_type != self.data_type and data_type is not None: 112 raise ValueError( 113 f"The project has already been initialized with data_type={self.data_type}!" 114 ) 115 if annotation_type != self.annotation_type and annotation_type != "none": 116 raise ValueError( 117 f"The project has already been initialized with annotation_type={self.annotation_type}!" 118 ) 119 self.annotation_path, data_path = self._read_paths() 120 if self.data_path is None: 121 self.data_path = data_path 122 # if data_path != self.data_path and data_path is not None: 123 # raise ValueError( 124 # f"The project has already been initialized with data_path={self.data_path}!" 125 # ) 126 if annotation_path != self.annotation_path and annotation_path is not None: 127 raise ValueError( 128 f"The project has already been initialized with annotation_path={self.annotation_path}!" 129 ) 130 self._update_configs() 131 132 def _make_prediction( 133 self, 134 prediction_name: str, 135 episode_names: List, 136 load_epochs: Union[List[int], int] = None, 137 parameters_update: Dict = None, 138 data_path: str = None, 139 file_paths: Set = None, 140 mode: str = "all", 141 augment_n: int = 0, 142 evaluate: bool = False, 143 task: TaskDispatcher = None, 144 embedding: bool = False, 145 annotation_type: str = "none", 146 ) -> Tuple[TaskDispatcher, Dict, str, torch.Tensor]: 147 """Generate a prediction. 
148 Parameters 149 ---------- 150 prediction_name : str 151 name of the prediction 152 episode_names : List 153 names of the episodes to use for the prediction 154 load_epochs : Union[List[int],int], optional 155 epochs to load for each episode; if a single integer is provided, it will be used for all episodes; 156 if None, the last epochs will be used 157 parameters_update : Dict, optional 158 dictionary with parameters to update the task parameters 159 data_path : str, optional 160 path to the data folder; if None, the data_path from the project will be used 161 file_paths : Set, optional 162 set of file paths to use for the prediction; if None, the data_path will be used 163 mode : str, default "all 164 mode of the prediction; can be "train", "val", "test" or "all 165 augment_n : int, default 0 166 number of augmentations to apply to the data; if 0, no augmentations are applied 167 evaluate : bool, default False 168 if True, the prediction will be evaluated and the results will be saved to the episode meta file 169 task : TaskDispatcher, optional 170 task object to use for the prediction; if None, a new task object will be created 171 embedding : bool, default False 172 if True, the prediction will be returned as an embedding 173 annotation_type : str, default "none 174 type of the annotation to use for the prediction; if "none", the annotation will not be used 175 Returns 176 ------- 177 task : TaskDispatcher 178 task object used for the prediction 179 parameters : Dict 180 parameters used for the prediction 181 mode : str 182 mode of the prediction 183 prediction : torch.Tensor 184 prediction tensor of shape (num_videos, num_behaviors, num_frames) 185 inference_time : str 186 time taken for the prediction in the format "HH:MM:SS" 187 behavior_dict : Dict 188 dictionary with behavior names and their indices 189 """ 190 191 names = [] 192 for episode_name in episode_names: 193 names += self._episodes().get_runs(episode_name) 194 if len(names) == 0: 195 
warnings.warn(f"None of the episodes {episode_names} exist!") 196 names = [None] 197 if load_epochs is None: 198 load_epochs = [None for _ in names] 199 elif isinstance(load_epochs, int): 200 load_epochs = [load_epochs for _ in names] 201 assert len(load_epochs) == len( 202 names 203 ), f"Length of load_epochs ({len(load_epochs)}) must match the number of episodes ({len(names)})!" 204 prediction = None 205 decision_thresholds = None 206 time_total = 0 207 behavior_dicts = [ 208 self.get_behavior_dictionary(episode_name) for episode_name in names 209 ] 210 211 if not all( 212 [ 213 set(d.values()) == set(behavior_dicts[0].values()) 214 for d in behavior_dicts[1:] 215 ] 216 ): 217 raise ValueError( 218 f"Episodes {episode_names} have different sets of behaviors!" 219 ) 220 behaviors = list(behavior_dicts[0].values()) 221 222 for episode_name, load_epoch, behavior_dict in zip( 223 names, load_epochs, behavior_dicts 224 ): 225 print(f"episode {episode_name}") 226 task, parameters, data_mode = self._make_task_prediction( 227 prediction_name=prediction_name, 228 load_episode=episode_name, 229 parameters_update=parameters_update, 230 load_epoch=load_epoch, 231 data_path=data_path, 232 mode=mode, 233 file_paths=file_paths, 234 task=task, 235 decision_thresholds=decision_thresholds, 236 annotation_type=annotation_type, 237 ) 238 # data_mode = "train" if mode == "all" else mode 239 time_start = time.time() 240 new_pred = task.predict( 241 data_mode, 242 raw_output=True, 243 apply_primary_function=True, 244 augment_n=augment_n, 245 embedding=embedding, 246 ) 247 indices = [ 248 behaviors.index(behavior_dict[i]) for i in range(new_pred.shape[1]) 249 ] 250 new_pred = new_pred[:, indices, :] 251 time_end = time.time() 252 time_total += time_end - time_start 253 if evaluate: 254 _, metrics = task.evaluate_prediction( 255 new_pred, data=data_mode, indices=indices 256 ) 257 if mode == "val": 258 self._update_episode_metrics(episode_name, metrics) 259 if prediction is None: 260 
prediction = new_pred 261 else: 262 prediction += new_pred 263 print("\n") 264 hours = int(time_total // 3600) 265 time_total -= hours * 3600 266 minutes = int(time_total // 60) 267 time_total -= minutes * 60 268 seconds = int(time_total) 269 inference_time = f"{hours}:{minutes:02}:{seconds:02}" 270 prediction /= len(names) 271 return ( 272 task, 273 parameters, 274 data_mode, 275 prediction, 276 inference_time, 277 behavior_dicts[0], 278 ) 279 280 def _make_task_prediction( 281 self, 282 prediction_name: str, 283 load_episode: str = None, 284 parameters_update: Dict = None, 285 load_epoch: int = None, 286 data_path: str = None, 287 annotation_path: str = None, 288 mode: str = "val", 289 file_paths: Set = None, 290 decision_thresholds: List = None, 291 task: TaskDispatcher = None, 292 annotation_type: str = "none", 293 ) -> Tuple[TaskDispatcher, Dict, str]: 294 """Make a `TaskDispatcher` object that will be used to generate a prediction.""" 295 if parameters_update is None: 296 parameters_update = {} 297 parameters_update_second = {} 298 if mode == "all" or data_path is not None or file_paths is not None: 299 parameters_update_second["training"] = { 300 "val_frac": 0, 301 "test_frac": 0, 302 "partition_method": "random", 303 "save_split": False, 304 "split_path": None, 305 } 306 mode = "train" 307 if decision_thresholds is not None: 308 if ( 309 len(decision_thresholds) 310 == self._episode(load_episode).get_num_classes() 311 ): 312 parameters_update_second["general"] = { 313 "threshold_value": decision_thresholds 314 } 315 else: 316 raise ValueError( 317 f"The length of the decision thresholds {decision_thresholds} " 318 f"must be equal to the length of the behaviors dictionary " 319 f"{self._episode(load_episode).get_behaviors_dict()}" 320 ) 321 data_param_update = {} 322 if data_path is not None: 323 data_param_update = {"data_path": data_path} 324 if annotation_path is None: 325 data_param_update["annotation_path"] = data_path 326 if annotation_path is not 
None: 327 data_param_update["annotation_path"] = annotation_path 328 if file_paths is not None: 329 data_param_update = {"data_path": None, "file_paths": file_paths} 330 parameters_update = self._update(parameters_update, {"data": data_param_update}) 331 if data_path is not None or file_paths is not None: 332 general_update = { 333 "annotation_type": annotation_type, 334 "only_load_annotated": False, 335 } 336 else: 337 general_update = {} 338 parameters_update = self._update(parameters_update, {"general": general_update}) 339 task, parameters = self._make_task( 340 episode_name=prediction_name, 341 load_episode=load_episode, 342 parameters_update=parameters_update, 343 parameters_update_second=parameters_update_second, 344 load_epoch=load_epoch, 345 purpose="prediction", 346 task=task, 347 ) 348 return task, parameters, mode 349 350 def _make_task_training( 351 self, 352 episode_name: str, 353 load_episode: str = None, 354 parameters_update: Dict = None, 355 load_epoch: int = None, 356 load_search: str = None, 357 load_parameters: list = None, 358 round_to_binary: list = None, 359 load_strict: bool = True, 360 continuing: bool = False, 361 task: TaskDispatcher = None, 362 mask_name: str = None, 363 throwaway: bool = False, 364 ) -> Tuple[TaskDispatcher, Dict, str]: 365 """Make a `TaskDispatcher` object that will be used to generate a prediction.""" 366 if parameters_update is None: 367 parameters_update = {} 368 if continuing: 369 purpose = "continuing" 370 else: 371 purpose = "training" 372 if mask_name is not None: 373 mask_name = os.path.join(self._mask_path(), f"{mask_name}.pickle") 374 parameters_update_second = {"data": {"real_lens": mask_name}} 375 if throwaway: 376 parameters_update = self._update( 377 parameters_update, {"training": {"normalize": False, "device": "cpu"}} 378 ) 379 return self._make_task( 380 episode_name, 381 load_episode, 382 parameters_update, 383 parameters_update_second, 384 load_epoch, 385 load_search, 386 load_parameters, 387 
round_to_binary, 388 purpose, 389 task, 390 load_strict=load_strict, 391 ) 392 393 def _make_parameters( 394 self, 395 episode_name: str, 396 load_episode: str = None, 397 parameters_update: Dict = None, 398 parameters_update_second: Dict = None, 399 load_epoch: int = None, 400 load_search: str = None, 401 load_parameters: list = None, 402 round_to_binary: list = None, 403 purpose: str = "train", 404 load_strict: bool = True, 405 ): 406 """Construct a parameters dictionary.""" 407 if parameters_update is None: 408 parameters_update = {} 409 pars_update = deepcopy(parameters_update) 410 if parameters_update_second is None: 411 parameters_update_second = {} 412 if ( 413 purpose == "prediction" 414 and "model" in pars_update.keys() 415 and pars_update["general"]["model_name"] != "motionbert" 416 ): 417 raise ValueError("Cannot change model parameters after training!") 418 if purpose in ["continuing", "prediction"] and load_episode is not None: 419 read_parameters = self._read_parameters() 420 parameters = self._episodes().load_parameters(load_episode) 421 parameters["metrics"] = self._update( 422 read_parameters["metrics"], parameters["metrics"] 423 ) 424 parameters["ssl"] = self._update( 425 read_parameters["ssl"], parameters.get("ssl", {}) 426 ) 427 else: 428 parameters = self._read_parameters() 429 if "model" in pars_update: 430 model_params = pars_update.pop("model") 431 else: 432 model_params = None 433 if "features" in pars_update: 434 feat_params = pars_update.pop("features") 435 else: 436 feat_params = None 437 if "augmentations" in pars_update: 438 aug_params = pars_update.pop("augmentations") 439 else: 440 aug_params = None 441 parameters = self._update(parameters, pars_update) 442 if pars_update.get("general", {}).get("model_name") is not None: 443 model_name = parameters["general"]["model_name"] 444 parameters["model"] = self._open_yaml( 445 os.path.join(self.project_path, "config", "model", f"{model_name}.yaml") 446 ) 447 if pars_update.get("general", 
{}).get("feature_extraction") is not None: 448 feat_name = parameters["general"]["feature_extraction"] 449 parameters["features"] = self._open_yaml( 450 os.path.join( 451 self.project_path, "config", "features", f"{feat_name}.yaml" 452 ) 453 ) 454 aug_name = options.extractor_to_transformer[ 455 parameters["general"]["feature_extraction"] 456 ] 457 parameters["augmentations"] = self._open_yaml( 458 os.path.join( 459 self.project_path, "config", "augmentations", f"{aug_name}.yaml" 460 ) 461 ) 462 if model_params is not None: 463 parameters["model"] = self._update(parameters["model"], model_params) 464 if feat_params is not None: 465 parameters["features"] = self._update(parameters["features"], feat_params) 466 if aug_params is not None: 467 parameters["augmentations"] = self._update( 468 parameters["augmentations"], aug_params 469 ) 470 if load_search is not None: 471 parameters = self._update_with_search( 472 parameters, load_search, load_parameters, round_to_binary 473 ) 474 parameters = self._fill( 475 parameters, 476 episode_name, 477 load_episode, 478 load_epoch=load_epoch, 479 load_strict=load_strict, 480 only_load_model=(purpose != "continuing"), 481 continuing=(purpose in ["prediction", "continuing"]), 482 enforce_split_parameters=(purpose == "prediction"), 483 ) 484 parameters = self._update(parameters, parameters_update_second) 485 return parameters 486 487 def _make_task( 488 self, 489 episode_name: str, 490 load_episode: str = None, 491 parameters_update: Dict = None, 492 parameters_update_second: Dict = None, 493 load_epoch: int = None, 494 load_search: str = None, 495 load_parameters: list = None, 496 round_to_binary: list = None, 497 purpose: str = "train", 498 task: TaskDispatcher = None, 499 load_strict: bool = True, 500 ) -> Tuple[TaskDispatcher, Union[CommentedMap, dict]]: 501 """Make a `TaskDispatcher` object. 502 503 The task parameters are read from the config files and then updated with the 504 parameters_update dictionary. 
The model can be either initialized from scratch or loaded from one of the 505 previous experiments. All parameters and results are saved in the meta files and can be accessed with the 506 list_episodes() function. The train/test/validation split is saved and loaded from a file whenever the 507 same split parameters are used. The pre-computed datasets are also saved and loaded whenever the same 508 data parameters are used. 509 510 Parameters 511 ---------- 512 episode_name : str 513 the name of the episode 514 load_episode : str, optional 515 the (previously run) episode name to load the model from 516 parameters_update : dict, optional 517 the dictionary used to update the parameters from the config 518 parameters_update_second : dict, optional 519 the dictionary used to update the parameters after the automatic fill-out 520 load_epoch : int, optional 521 the epoch to load (if load_episodes is not None); if not provided, the last epoch is used 522 load_search : str, optional 523 the hyperparameter search result to load 524 load_parameters : list, optional 525 a list of string names of the parameters to load from load_search (if not provided, all parameters 526 are loaded) 527 round_to_binary : list, optional 528 a list of string names of the loaded parameters that should be rounded to the nearest power of two 529 purpose : {"train", "continuing", "prediction"} 530 the purpose of the task object (`"train"` for training from scratch, `"continuing"` for continuing 531 the training of an interrupted episode, `"prediction"` for generating a prediction) 532 task : TaskDispatcher, optional 533 a pre-existing task; if provided, the method will update the task instead of creating a new one 534 (this might save time, mainly on dataset loading) 535 536 Returns 537 ------- 538 task : TaskDispatcher 539 the `TaskDispatcher` instance 540 parameters : dict 541 the parameters dictionary that describes the task 542 543 """ 544 parameters = self._make_parameters( 545 episode_name, 
546 load_episode, 547 parameters_update, 548 parameters_update_second, 549 load_epoch, 550 load_search, 551 load_parameters, 552 round_to_binary, 553 purpose, 554 load_strict=load_strict, 555 ) 556 if task is None: 557 task = TaskDispatcher(parameters) 558 else: 559 task.update_task(parameters) 560 self._save_stores(parameters) 561 return task, parameters 562 563 def get_decision_thresholds( 564 self, 565 episode_names: List, 566 metric_name: str = "f1", 567 parameters_update: Dict = None, 568 load_epochs: List = None, 569 remove_saved_features: bool = False, 570 ) -> Tuple[List, List, TaskDispatcher]: 571 """Compute optimal decision thresholds or load them if they have been computed before. 572 573 Parameters 574 ---------- 575 episode_names : List 576 a list of episode names 577 metric_name : {"f1", "segmental_f1", "semisegmental_f1", "f_beta", "segmental_f_beta"} 578 the metric to optimize 579 parameters_update : dict, optional 580 the parameter update dictionary 581 load_epochs : list, optional 582 a list of epochs to load (by default last are loaded) 583 remove_saved_features : bool, default False 584 if `True`, the dataset will be deleted after the computation 585 586 Returns 587 ------- 588 thresholds : list 589 a list of float decision threshold values 590 classes : list 591 the label names corresponding to the values 592 task : TaskDispatcher | None 593 the task used in computation 594 595 """ 596 parameters = self._make_parameters( 597 "_", 598 episode_names[0], 599 parameters_update, 600 {}, 601 load_epochs[0], 602 purpose="prediction", 603 ) 604 thresholds = self._thresholds().find_thresholds( 605 episode_names, 606 load_epochs, 607 metric_name, 608 metric_parameters=parameters["metrics"][metric_name], 609 ) 610 task = None 611 behaviors = list(self._episode(episode_names[0]).get_behaviors_dict().values()) 612 return thresholds, behaviors, task 613 614 def run_episode( 615 self, 616 episode_name: str, 617 load_episode: str = None, 618 parameters_update: 
Dict = None, 619 task: TaskDispatcher = None, 620 load_epoch: int = None, 621 load_search: str = None, 622 load_parameters: list = None, 623 round_to_binary: list = None, 624 load_strict: bool = True, 625 n_seeds: int = 1, 626 force: bool = False, 627 suppress_name_check: bool = False, 628 remove_saved_features: bool = False, 629 mask_name: str = None, 630 autostop_metric: str = None, 631 autostop_interval: int = 50, 632 autostop_threshold: float = 0.001, 633 loading_bar: bool = False, 634 trial: Tuple = None, 635 ) -> TaskDispatcher: 636 """Run an episode. 637 638 The task parameters are read from the config files and then updated with the 639 parameters_update dictionary. The model can be either initialized from scratch or loaded from one of the 640 previous experiments. All parameters and results are saved in the meta files and can be accessed with the 641 list_episodes() function. The train/test/validation split is saved and loaded from a file whenever the 642 same split parameters are used. The pre-computed datasets are also saved and loaded whenever the same 643 data parameters are used. 644 645 You can use the autostop parameters to finish training when the parameters are not improving. It will be 646 stopped if the average value of `autostop_metric` over the last `autostop_interval` epochs is smaller than 647 the average over the previous `autostop_interval` epochs + `autostop_threshold`. For example, if the 648 current epoch is 120 and `autostop_interval` is 50, the averages over epochs 70-120 and 20-70 will be compared. 
649 650 Parameters 651 ---------- 652 episode_name : str 653 the episode name 654 load_episode : str, optional 655 the (previously run) episode name to load the model from; if the episode has multiple runs, 656 the new episode will have the same number of runs, each starting with one of the pre-trained models 657 parameters_update : dict, optional 658 the dictionary used to update the parameters from the config files 659 task : TaskDispatcher, optional 660 a pre-existing `TaskDispatcher` object (if provided, the method will update it instead of creating 661 a new instance) 662 load_epoch : int, optional 663 the epoch to load (if load_episodes is not None); if not provided, the last epoch is used 664 load_search : str, optional 665 the hyperparameter search result to load 666 load_parameters : list, optional 667 a list of string names of the parameters to load from load_search (if not provided, all parameters 668 are loaded) 669 round_to_binary : list, optional 670 a list of string names of the loaded parameters that should be rounded to the nearest power of two 671 load_strict : bool, default True 672 if `False`, matching weights will be loaded from `load_episode` and differences in parameter name lists and 673 weight shapes will be ignored; otherwise mismatches will prompt a `RuntimeError` 674 n_seeds : int, default 1 675 the number of runs to perform; if `n_seeds > 1`, the episodes will be named `episode_name#run_index`, e.g. 676 `test_episode#0` and `test_episode#1` 677 force : bool, default False 678 if `True` and an episode with name `episode_name` already exists, it will be overwritten (use with caution!) 
679 suppress_name_check : bool, default False 680 if `True`, episode names with a double colon are allowed (please don't use this option unless you understand 681 why they are usually forbidden) 682 remove_saved_features : bool, default False 683 if `True`, the dataset will be deleted after training 684 mask_name : str, optional 685 the name of the real_lens to apply 686 autostop_metric : str, optional 687 the autostop metric (can be any one of the tracked metrics of `'loss'`) 688 autostop_interval : int, default 50 689 the number of epochs to average the autostop metric over 690 autostop_threshold : float, default 0.001 691 the autostop difference threshold 692 loading_bar : bool, default False 693 if `True`, a loading bar will be displayed 694 trial : tuple, optional 695 a tuple of (trial, metric) for hyperparameter search 696 697 Returns 698 ------- 699 TaskDispatcher 700 the `TaskDispatcher` object 701 702 """ 703 704 import gc 705 706 gc.collect() 707 if torch.cuda.is_available(): 708 torch.cuda.empty_cache() 709 710 if type(n_seeds) is not int or n_seeds < 1: 711 raise ValueError( 712 f"The n_seeds parameter has to be an integer larger than 0; got {n_seeds}" 713 ) 714 if n_seeds > 1 and mask_name is not None: 715 raise ValueError("Cannot apply a real_lens with n_seeds > 1") 716 self._check_episode_validity( 717 episode_name, allow_doublecolon=suppress_name_check, force=force 718 ) 719 load_runs = self._episodes().get_runs(load_episode) 720 if len(load_runs) > 1: 721 task = self.run_episodes( 722 episode_names=[ 723 f'{episode_name}#{run.split("#")[-1]}' for run in load_runs 724 ], 725 load_episodes=load_runs, 726 parameters_updates=[parameters_update for _ in load_runs], 727 load_epochs=[load_epoch for _ in load_runs], 728 load_searches=[load_search for _ in load_runs], 729 load_parameters=[load_parameters for _ in load_runs], 730 round_to_binary=[round_to_binary for _ in load_runs], 731 load_strict=[load_strict for _ in load_runs], 732 
suppress_name_check=True, 733 force=force, 734 remove_saved_features=False, 735 ) 736 if remove_saved_features: 737 self._remove_stores( 738 { 739 "general": task.general_parameters, 740 "data": task.data_parameters, 741 "features": task.feature_parameters, 742 } 743 ) 744 if n_seeds > 1: 745 warnings.warn( 746 f"The n_seeds parameter is disregarded since load_episode={load_episode} has multiple runs" 747 ) 748 elif n_seeds > 1: 749 750 self.run_episodes( 751 episode_names=[f"{episode_name}#{i}" for i in range(n_seeds)], 752 load_episodes=[load_episode for _ in range(n_seeds)], 753 parameters_updates=[parameters_update for _ in range(n_seeds)], 754 load_epochs=[load_epoch for _ in range(n_seeds)], 755 load_searches=[load_search for _ in range(n_seeds)], 756 load_parameters=[load_parameters for _ in range(n_seeds)], 757 round_to_binary=[round_to_binary for _ in range(n_seeds)], 758 load_strict=[load_strict for _ in range(n_seeds)], 759 suppress_name_check=True, 760 force=force, 761 remove_saved_features=remove_saved_features, 762 ) 763 else: 764 print(f"TRAINING {episode_name}") 765 try: 766 task, parameters = self._make_task_training( 767 episode_name, 768 load_episode, 769 parameters_update, 770 load_epoch, 771 load_search, 772 load_parameters, 773 round_to_binary, 774 continuing=False, 775 task=task, 776 mask_name=mask_name, 777 load_strict=load_strict, 778 ) 779 self._save_episode( 780 episode_name, 781 parameters, 782 task.behaviors_dict(), 783 norm_stats=task.get_normalization_stats(), 784 ) 785 time_start = time.time() 786 if trial is not None: 787 trial, metric = trial 788 else: 789 trial, metric = None, None 790 logs = task.train( 791 autostop_metric=autostop_metric, 792 autostop_interval=autostop_interval, 793 autostop_threshold=autostop_threshold, 794 loading_bar=loading_bar, 795 trial=trial, 796 optimized_metric=metric, 797 ) 798 time_end = time.time() 799 time_total = time_end - time_start 800 hours = int(time_total // 3600) 801 time_total -= hours * 
3600 802 minutes = int(time_total // 60) 803 time_total -= minutes * 60 804 seconds = int(time_total) 805 training_time = f"{hours}:{minutes:02}:{seconds:02}" 806 self._update_episode_results(episode_name, logs, training_time) 807 if remove_saved_features: 808 self._remove_stores(parameters) 809 print("\n") 810 return task 811 812 except Exception as e: 813 if isinstance(e, optuna.exceptions.TrialPruned): 814 raise e 815 else: 816 # if str(e) != f"The {episode_name} episode name is already in use!": 817 # self.remove_episode(episode_name) 818 raise RuntimeError(f"Episode {episode_name} could not run") 819 820 def run_episodes( 821 self, 822 episode_names: List, 823 load_episodes: List = None, 824 parameters_updates: List = None, 825 load_epochs: List = None, 826 load_searches: List = None, 827 load_parameters: List = None, 828 round_to_binary: List = None, 829 load_strict: List = None, 830 force: bool = False, 831 suppress_name_check: bool = False, 832 remove_saved_features: bool = False, 833 ) -> TaskDispatcher: 834 """Run multiple episodes in sequence (and re-use previously loaded information). 835 836 For each episode, the task parameters are read from the config files and then updated with the 837 parameter_update dictionary. The model can be either initialized from scratch or loaded from one of the 838 previous experiments. All parameters and results are saved in the meta files and can be accessed with the 839 list_episodes() function. The train/test/validation split is saved and loaded from a file whenever the 840 same split parameters are used. The pre-computed datasets are also saved and loaded whenever the same 841 data parameters are used. 
842 843 Parameters 844 ---------- 845 episode_names : list 846 a list of strings of episode names 847 load_episodes : list, optional 848 a list of strings of (previously run) episode names to load the model from; if the episode has multiple runs, 849 the new episode will have the same number of runs, each starting with one of the pre-trained models 850 parameters_updates : list, optional 851 a list of dictionaries used to update the parameters from the config 852 load_epochs : list, optional 853 a list of integers used to specify the epoch to load (if load_episodes is not None) 854 load_searches : list, optional 855 a list of strings of hyperparameter search results to load 856 load_parameters : list, optional 857 a list of lists of string names of the parameters to load from the searches 858 round_to_binary : list, optional 859 a list of string names of the loaded parameters that should be rounded to the nearest power of two 860 load_strict : list, optional 861 a list of boolean values specifying weight loading policy: if `False`, matching weights will be loaded from 862 the corresponding episode and differences in parameter name lists and 863 weight shapes will be ignored; otherwise mismatches will prompt a `RuntimeError` (by default `True` for 864 every episode) 865 force : bool, default False 866 if `True` and an episode name is already taken, it will be overwritten (use with caution!) 
def continue_episode(
    self,
    episode_name: str,
    num_epochs: int = None,
    task: "TaskDispatcher" = None,
    n_seeds: int = 1,
    remove_saved_features: bool = False,
    device: str = "cuda",
    num_cpus: int = None,
) -> "TaskDispatcher":
    """Load an older episode and continue running from the latest checkpoint.

    All parameters as well as the model and optimizer state dictionaries are loaded from the episode.

    Parameters
    ----------
    episode_name : str
        the name of the episode to continue
    num_epochs : int, optional
        the new number of epochs; if `None`, only the runs that are reported as unfinished are continued
    task : TaskDispatcher, optional
        a pre-existing task; if provided, the method will update the task instead of creating a new one
        (this might save time, mainly on dataset loading)
    n_seeds : int, default 1
        the number of runs to perform; if `n_seeds > 1`, the episodes will be named `episode_name#run_index`, e.g.
        `test_episode#0` and `test_episode#1`; if fewer runs exist than `n_seeds`, the missing ones are started
        with `run_episode` using the parameters of the first existing run
    remove_saved_features : bool, default False
        if `True`, pre-computed features will be deleted after the run
    device : str, default "cuda"
        the torch device to use
    num_cpus : int, optional
        the number of CPUs to use for data loading; if `None`, the number of available CPUs will be used

    Returns
    -------
    TaskDispatcher
        the task dispatcher (the `task` argument is returned unchanged if no run was trained)

    """
    runs = self._episodes().get_runs(episode_name)
    # Stays `None` when no run is actually retrained, so that the clean-up at the end
    # can be skipped instead of failing on an unbound name.
    parameters = None
    for run in runs:
        print(f"TRAINING {run}")
        # Without a new epoch budget there is nothing left to do for a finished run.
        if num_epochs is None and not self._episode(run).unfinished():
            continue
        parameters_update = {
            "training": {
                "num_epochs": num_epochs,
                "device": device,
            },
            "general": {"num_cpus": num_cpus},
        }
        task, parameters = self._make_task_training(
            run,
            load_episode=run,
            parameters_update=parameters_update,
            continuing=True,
            task=task,
        )
        time_start = time.time()
        logs = task.train()
        time_end = time.time()
        # NOTE(review): `_training_time` presumably returns the previously accumulated time in
        # seconds, or NaN when it is unknown - confirm against its definition.
        old_time = self._training_time(run)
        if not np.isnan(old_time):
            # Report the cumulative time: the previous sessions plus this one.
            time_total = int(time_end + old_time - time_start)
            minutes, seconds = divmod(time_total, 60)
            hours, minutes = divmod(minutes, 60)
            training_time = f"{hours}:{minutes:02}:{seconds:02}"
        else:
            # The earlier part is unknown, so a total cannot be reported.
            training_time = np.nan
        self._save_episode(
            run,
            parameters,
            task.behaviors_dict(),
            suppress_validation=True,
            training_time=training_time,
            norm_stats=task.get_normalization_stats(),
        )
        self._update_episode_results(run, logs)
        print("\n")
    if len(runs) < n_seeds:
        # Top up the episode with fresh runs that share the parameters of the first run.
        for i in range(len(runs), n_seeds):
            self.run_episode(
                f"{episode_name}#{i}",
                parameters_update=self._episodes().load_parameters(runs[0]),
                task=task,
                suppress_name_check=True,
            )
    if remove_saved_features and parameters is not None:
        self._remove_stores(parameters)
    return task
def run_default_hyperparameter_search(
    self,
    search_name: str,
    model_name: str,
    metric: str = "f1",
    best_n: int = 3,
    direction: str = "maximize",
    load_episode: str = None,
    load_epoch: int = None,
    load_strict: bool = True,
    prune: bool = True,
    force: bool = False,
    remove_saved_features: bool = False,
    overlap: float = 0,
    num_epochs: int = 50,
    test_frac: float = None,
    n_trials: int = 150,
    batch_size: int = 32,
) -> Dict:
    """Run an optuna hyperparameter search with default parameters for a model.

    For the vast majority of cases, optimizing the default parameters should be enough.
    Check out `dlc2action.options.model_hyperparameters` for the lists of parameters.
    There are also options to set overlap, test fraction and number of epochs parameters for the search without
    modifying the project config files. However, if you want something more complex, look into
    `Project.run_hyperparameter_search`.

    The task parameters are read from the config files and updated with the parameters_update dictionary.
    The model can be either initialized from scratch or loaded from a previously run episode.
    For each trial, the objective metric is averaged over a few best epochs.

    Parameters
    ----------
    search_name : str
        the name of the search to store it in the meta files and load in run_episode
    model_name : str
        the name of the model; has to be a key of `dlc2action.options.model_hyperparameters`
    metric : str, default 'f1'
        the metric to maximize/minimize (see direction); if the metric has an `"average"` parameter and it is set to
        `"none"` in the config files, it will be reset to `"macro"` for the search
    best_n : int, default 3
        the number of epochs to average the metric; if 0, the last value is taken
    direction : {'maximize', 'minimize'}
        optimization direction
    load_episode : str, optional
        the name of the episode to load the model from
    load_epoch : int, optional
        the epoch to load the model from (if not provided, the last checkpoint is used)
    load_strict : bool, default True
        if `True`, the model will be loaded only if the parameters match exactly
    prune : bool, default True
        if `True`, experiments where the optimized metric is improving too slowly will be terminated
        (with optuna HyperBand pruner)
    force : bool, default False
        if `True`, existing searches with the same name will be overwritten
    remove_saved_features : bool, default False
        if `True`, pre-computed features will be deleted after each run (if the data parameters change)
    overlap : float, default 0
        the overlap to use for the search
    num_epochs : int, default 50
        the number of epochs to use for the search
    test_frac : float, optional
        the test fraction to use for the search
    n_trials : int, default 150
        the number of trials to run
    batch_size : int, default 32
        the batch size to use for the search

    Returns
    -------
    best_parameters : dict
        a dictionary of best parameters

    Raises
    ------
    ValueError
        if there is no default search space for `model_name`

    """
    if model_name not in options.model_hyperparameters:
        raise ValueError(
            f"There is no default search space for {model_name}! Please choose from {options.model_hyperparameters.keys()} or try project.run_hyperparameter_search()"
        )
    pars = {
        "general": {"overlap": overlap, "model_name": model_name},
        "training": {"num_epochs": num_epochs, "batch_size": batch_size},
    }
    if test_frac is not None:
        pars["training"]["test_frac"] = test_frac
    # NOTE(review): metric names ending in a numeric suffix are skipped here; presumably those
    # are per-class variants without their own config entry - confirm.
    if not metric.split("_")[-1].isnumeric():
        project_pars = self._read_parameters()
        # A metric without an entry (or with an empty one) in the project config simply has no
        # "average" setting to override; do not fail on it.
        metric_pars = (project_pars.get("metrics") or {}).get(metric) or {}
        if metric_pars.get("average") == "none":
            pars["metrics"] = {metric: {"average": "macro"}}
    return self.run_hyperparameter_search(
        search_name=search_name,
        search_space=options.model_hyperparameters[model_name],
        metric=metric,
        n_trials=n_trials,
        best_n=best_n,
        parameters_update=pars,
        direction=direction,
        load_episode=load_episode,
        load_epoch=load_epoch,
        load_strict=load_strict,
        prune=prune,
        force=force,
        remove_saved_features=remove_saved_features,
    )
def run_hyperparameter_search(
    self,
    search_name: str,
    search_space: Dict,
    metric: str = "f1",
    n_trials: int = 20,
    best_n: int = 1,
    parameters_update: Dict = None,
    direction: str = "maximize",
    load_episode: str = None,
    load_epoch: int = None,
    load_strict: bool = True,
    prune: bool = False,
    force: bool = False,
    remove_saved_features: bool = False,
    make_plots: bool = True,
) -> Dict:
    """Run an optuna hyperparameter search.

    For a simpler function that fits most use cases, check out `Project.run_default_hyperparameter_search()`.

    To use a default search space with this method, import `dlc2action.options.model_hyperparameters`. It is
    a dictionary where keys are model names and values are default search spaces.

    The task parameters are read from the config files and updated with the parameters_update dictionary.
    The model can be either initialized from scratch or loaded from a previously run episode.
    For each trial, the objective metric is averaged over a few best epochs.

    Parameters
    ----------
    search_name : str
        the name of the search to store it in the meta files and load in run_episode
    search_space : dict
        a dictionary representing the search space; of this general structure:
        {'group/param_name': ('float/int/float_log/int_log', start, end),
        'group/param_name': ('categorical', [choices])}, e.g.
        {'data/overlap': ('int', 5, 100), 'training/lr': ('float_log', 1e-4, 1e-2),
        'data/feature_extraction': ('categorical', ['kinematic', 'bones'])};
    metric : str, default f1
        the metric to maximize/minimize (see direction)
    n_trials : int, default 20
        the number of optimization trials to run
    best_n : int, default 1
        the number of epochs to average the metric; if 0, the last value is taken
    parameters_update : dict, optional
        the parameters update dictionary
    direction : {'maximize', 'minimize'}
        optimization direction
    load_episode : str, optional
        the name of the episode to load the model from
    load_epoch : int, optional
        the epoch to load the model from (if not provided, the last checkpoint is used)
    load_strict : bool, default True
        if `True`, the model will be loaded only if the parameters match exactly
    prune : bool, default False
        if `True`, experiments where the optimized metric is improving too slowly will be terminated
        (with optuna HyperBand pruner)
    force : bool, default False
        if `True`, existing searches with the same name will be overwritten
    remove_saved_features : bool, default False
        if `True`, pre-computed features will be deleted after each run (if the data parameters change)
    make_plots : bool, default True
        if `True`, contour and parameter importance plots are written as html files to the search folder

    Returns
    -------
    dict
        a dictionary of best parameters

    Raises
    ------
    ValueError
        if the best metric value is 0 or infinity

    """
    self._check_search_validity(search_name, force=force)
    print(f"SEARCH {search_name}")
    # The trials run under a temporary episode name; drop leftovers of an earlier search first.
    self.remove_episode(f"_{search_name}")
    if parameters_update is None:
        parameters_update = {}
    parameters_update = self._update(
        parameters_update, {"general": {"metric_functions": {metric}}}
    )
    parameters = self._make_parameters(
        f"_{search_name}",
        load_episode,
        parameters_update,
        parameters_update_second={"training": {"model_save_path": None}},
        load_epoch=load_epoch,
        load_strict=load_strict,
    )

    if prune:
        pruner = optuna.pruners.HyperbandPruner()
    else:
        pruner = optuna.pruners.NopPruner()
    study = optuna.create_study(direction=direction, pruner=pruner)
    runner = _Runner(
        search_space=search_space,
        load_episode=load_episode,
        load_epoch=load_epoch,
        metric=metric,
        average=best_n,
        task=None,
        remove_saved_features=remove_saved_features,
        project=self,
        search_name=search_name,
    )
    study.optimize(lambda trial: runner.run(trial, parameters), n_trials=n_trials)
    if make_plots:
        search_path = self._search_path(search_name)
        # The folder may already exist when a search is overwritten with `force=True`,
        # and its parent may not exist yet for the first search of a project.
        os.makedirs(search_path, exist_ok=True)
        fig = optuna.visualization.plot_contour(study)
        plotly.offline.plot(
            fig, filename=os.path.join(search_path, f"{search_name}_contour.html")
        )
        fig = optuna.visualization.plot_param_importances(study)
        plotly.offline.plot(
            fig,
            filename=os.path.join(search_path, f"{search_name}_importances.html"),
        )
    best_params = study.best_params
    best_value = study.best_value
    if best_value == 0 or best_value == float("inf"):
        raise ValueError(
            f"Best metric value is {best_value}, check your partition method and make sure that all behaviors are present in the validation set!"
        )
    self._save_search(
        search_name,
        parameters,
        n_trials,
        best_params,
        best_value,
        metric,
        search_space,
    )
    self.remove_episode(f"_{search_name}")
    runner.clean()
    print(f"best parameters: {best_params}")
    print("\n")
    return best_params
def run_prediction(
    self,
    prediction_name: str,
    episode_names: List,
    load_epochs: List = None,
    parameters_update: Dict = None,
    augment_n: int = 10,
    data_path: str = None,
    mode: str = "all",
    file_paths: Set = None,
    remove_saved_features: bool = False,
    frame_number_map_file: str = None,
    force: bool = False,
    embedding: bool = False,
) -> None:
    """Generate and save a prediction with models from previously run episodes.

    The probabilities predicted by the models are averaged.
    The full-length prediction is handed over to `_save_prediction`, which stores it under `prediction_name`
    in the project_name/results/predictions folder. The saved data is a nested dictionary where the first-level
    keys are the video ids, the second-level keys are the clip ids (like individual names) and the values
    are the prediction arrays.

    Parameters
    ----------
    prediction_name : str
        the name of the prediction
    episode_names : list
        a list of string episode names to load the models from
    load_epochs : list or int, optional
        a list of integer epoch indices to load the model from; if None, the last ones are used, if int the same
        epoch is used for all episodes
    parameters_update : dict, optional
        a dictionary of parameter updates
    augment_n : int, default 10
        the number of augmentations to average over
    data_path : str, optional
        the data path to run the prediction for
    mode : {'all', 'test', 'val', 'train'}
        the subset of the data to make the prediction for (forced to 'all' if data_path is not None)
    file_paths : set, optional
        a set of string file paths (data with all prefixes + feature files, in any order) to run the prediction
        for
    remove_saved_features : bool, default False
        if `True`, pre-computed features will be deleted
    frame_number_map_file : str, optional
        path to the frame number map file (not used by this method)
    force : bool, default False
        if `True`, existing prediction with this name will be overwritten
    embedding : bool, default False
        if `True`, the prediction is made for the embedding task

    """
    self._check_prediction_validity(prediction_name, force=force)
    print(f"PREDICTION {prediction_name}")

    outcome = self._make_prediction(
        prediction_name,
        episode_names,
        load_epochs,
        parameters_update,
        data_path,
        file_paths,
        mode,
        augment_n,
        evaluate=False,
        embedding=embedding,
    )
    # `mode` is taken from the outcome because the prediction step may have changed it.
    task, parameters, mode, prediction, inference_time, behavior_dict = outcome

    # The full-length prediction has to be built before the stores are (optionally) removed.
    dataset = task.dataset(mode)
    full_prediction = dataset.generate_full_length_prediction(prediction)

    if remove_saved_features:
        self._remove_stores(parameters)

    save_arguments = (
        prediction_name,
        full_prediction,
        parameters,
        task,
        mode,
        embedding,
        inference_time,
        behavior_dict,
    )
    self._save_prediction(*save_arguments)
    print("\n")
def evaluate_prediction(
    self,
    prediction_name: str,
    parameters_update: Dict = None,
    data_path: str = None,
    annotation_path: str = None,
    file_paths: Set = None,
    mode: str = None,
    remove_saved_features: bool = False,
    annotation_type: str = "none",
    num_classes: int = None,  # Set when using data_path
) -> Tuple[float, dict]:
    """Evaluate a previously saved prediction.

    The pickled prediction files are read from the project_name/results/predictions/{prediction_name} folder,
    the task is rebuilt from the parameters saved with the prediction and the metrics are computed on it.

    Parameters
    ----------
    prediction_name : str
        the name of the prediction
    parameters_update : dict, optional
        a dictionary of parameter updates
    data_path : str, optional
        the data path to run the evaluation for
    annotation_path : str, optional
        the annotation path to run the evaluation for
    file_paths : set, optional
        a set of string file paths (data with all prefixes + feature files, in any order) to run the evaluation
        for
    mode : str, optional
        the subset of the data to make the evaluation for (forced to 'all' if data_path is not None)
    remove_saved_features : bool, default False
        if `True`, pre-computed features will be deleted
    annotation_type : str, default 'none'
        the type of annotation to use for evaluation
    num_classes : int, optional
        the number of classes in the dataset, must be set with data_path

    Returns
    -------
    results : dict
        a dictionary of average values of metric functions (as formatted by `Project._reformat_results`)

    Raises
    ------
    ValueError
        if `data_path` is provided without `num_classes`

    """
    prediction_path = os.path.join(
        self.project_path, "results", "predictions", f"{prediction_name}"
    )
    prediction_dict = {}
    for file_name in os.listdir(prediction_path):
        # NOTE: unpickling can execute arbitrary code; only evaluate prediction folders you trust.
        with open(os.path.join(prediction_path, file_name), "rb") as f:
            prediction = pickle.load(f)
        # The video id is the part of the file name that precedes "_{prediction_name}".
        video_id = file_name.split("_" + prediction_name)[0]
        prediction_dict[video_id] = prediction
    if parameters_update is None:
        parameters_update = {}
    parameters_update = self._update(
        self._predictions().load_parameters(prediction_name), parameters_update
    )
    # The model parameters are dropped before the task is built (no episode is loaded below);
    # tolerate saved parameters that have no "model" section.
    parameters_update.pop("model", None)
    if data_path is not None:
        # An explicit exception instead of `assert`, which is skipped when Python runs with -O.
        if num_classes is None:
            raise ValueError("num_classes must be provided if data_path is provided")
        # NOTE(review): one extra class is added in the exclusive case - presumably for the
        # background/"other" class; confirm.
        parameters_update["general"]["num_classes"] = num_classes + int(
            parameters_update["general"]["exclusive"]
        )
    task, parameters, mode = self._make_task_prediction(
        "_",
        load_episode=None,
        parameters_update=parameters_update,
        data_path=data_path,
        annotation_path=annotation_path,
        file_paths=file_paths,
        mode=mode,
        annotation_type=annotation_type,
    )
    results = task.evaluate_prediction(prediction_dict, data=mode)
    if remove_saved_features:
        self._remove_stores(parameters)
    # Only the second element of the evaluation output (the metrics) is reformatted and returned.
    results = Project._reformat_results(
        results[1],
        task.behaviors_dict(),
        exclusive=task.general_parameters["exclusive"],
    )
    return results
def evaluate(
    self,
    episode_names: List,
    load_epochs: List = None,
    augment_n: int = 0,
    data_path: str = None,
    file_paths: Set = None,
    mode: str = None,
    parameters_update: Dict = None,
    multiple_episode_policy: str = "average",
    remove_saved_features: bool = False,
    skip_updating_meta: bool = True,
    annotation_type: str = "none",
) -> Dict:
    """Load one or several models from previously run episodes to make an evaluation.

    By default it will run on the test (or validation, if there is no test) subset of the project dataset.

    Parameters
    ----------
    episode_names : list
        a list of string episode names to load the models from
    load_epochs : list, optional
        a list of integer epoch indices to load the model from; if None, the last ones are used
    augment_n : int, default 0
        the number of augmentations to average over
    data_path : str, optional
        the data path to run the prediction for
    file_paths : set, optional
        a set of files to run the prediction for
    mode : {'test', 'val', 'train', 'all'}
        the subset of the data to make the prediction for (forced to 'all' if data_path is not None;
        by default 'test' if test subset is not empty and 'val' otherwise)
    parameters_update : dict, optional
        a dictionary with parameter updates (cannot change model parameters)
    multiple_episode_policy : {'average', 'statistics'}
        the policy to use when multiple episodes are provided: with 'average' one prediction is made with all
        episodes and evaluated; with 'statistics' every run is evaluated separately and the mean, the standard
        deviation and the list of all values are returned for each metric
    remove_saved_features : bool, default False
        if `True`, the dataset will be deleted
    skip_updating_meta : bool, default True
        if `True`, the meta file will not be updated with the computed metrics
    annotation_type : str, default 'none'
        the annotation type that is passed on to the prediction step

    Returns
    -------
    metric : dict
        a dictionary of average values of metric functions

    Raises
    ------
    ValueError
        if `multiple_episode_policy` is not recognized

    """
    names = []
    for episode_name in episode_names:
        names += self._episodes().get_runs(episode_name)
    if len(set(episode_names)) == 1:
        print(f"EVALUATION {episode_names[0]}")
    else:
        print(f"EVALUATION {episode_names}")
    several_runs = len(names) > 1
    # Defined up front so that the summary at the end also works when no run is evaluated.
    parameters = None
    inference_time = None
    if multiple_episode_policy == "average":
        task, parameters, mode, prediction, inference_time, behavior_dict = (
            self._make_prediction(
                "_",
                episode_names,
                load_epochs,
                parameters_update,
                mode=mode,
                data_path=data_path,
                file_paths=file_paths,
                augment_n=augment_n,
                evaluate=several_runs,
                annotation_type=annotation_type,
            )
        )
        print("EVALUATE PREDICTION:")
        behavior_keys = list(behavior_dict.keys())
        indices = [behavior_keys.index(i) for i in range(len(behavior_dict))]
        _, results = task.evaluate_prediction(
            prediction, data=mode, indices=indices
        )
        if len(names) == 1 and mode == "val" and not skip_updating_meta:
            self._update_episode_metrics(names[0], results)
        results = Project._reformat_results(
            results,
            behavior_dict,
            exclusive=task.general_parameters["exclusive"],
        )

    elif multiple_episode_policy == "statistics":
        values = defaultdict(list)
        task = None
        for name in names:
            (
                task,
                parameters,
                mode,
                prediction,
                inference_time,
                behavior_dict,
            ) = self._make_prediction(
                "_",
                [name],
                load_epochs,
                parameters_update,
                mode=mode,
                data_path=data_path,
                file_paths=file_paths,
                augment_n=augment_n,
                evaluate=several_runs,
                task=task,
                annotation_type=annotation_type,
            )
            _, metrics = task.evaluate_prediction(
                prediction, data=mode, indices=list(behavior_dict.keys())
            )
            # A separate loop variable: `name` has to keep pointing at the run for the meta update.
            for metric_name, value in metrics.items():
                values[metric_name].append(value)
            if mode == "val" and not skip_updating_meta:
                self._update_episode_metrics(name, metrics)
        results = defaultdict(dict)
        mean_string = ""
        std_string = ""
        for key, value_list in values.items():
            results[key]["mean"] = np.mean(value_list)
            results[key]["std"] = np.std(value_list)
            results[key]["all"] = value_list
            mean_string += f"{key} {np.mean(value_list):.3f}, "
            std_string += f"{key} {np.std(value_list):.3f}, "
        print("MEAN:")
        print(mean_string)
        print("STD:")
        print(std_string)
    else:
        raise ValueError(
            f"The {multiple_episode_policy} multiple episode policy is not recognized; please choose "
            f"from ['average', 'statistics']"
        )
    if len(names) > 0 and remove_saved_features:
        self._remove_stores(parameters)
    print(f"Inference time: {inference_time}")
    print("\n")
    return results
{inference_time}") 1560 print("\n") 1561 return results 1562 1563 def run_suggestion( 1564 self, 1565 suggestions_name: str, 1566 error_episode: str = None, 1567 error_load_epoch: int = None, 1568 error_class: str = None, 1569 suggestions_prediction: str = None, 1570 suggestion_episodes: List = [None], 1571 suggestion_load_epoch: int = None, 1572 suggestion_classes: List = None, 1573 error_threshold: float = 0.5, 1574 error_threshold_diff: float = 0.1, 1575 error_hysteresis: bool = False, 1576 suggestion_threshold: Union[float, List] = 0.5, 1577 suggestion_threshold_diff: Union[float, List] = 0.1, 1578 suggestion_hysteresis: Union[bool, List] = True, 1579 min_frames_suggestion: int = 10, 1580 min_frames_al: int = 30, 1581 visibility_min_score: float = 0, 1582 visibility_min_frac: float = 0.7, 1583 augment_n: int = 0, 1584 exclude_classes: List = None, 1585 exclude_threshold: Union[float, List] = 0.6, 1586 exclude_threshold_diff: Union[float, List] = 0.1, 1587 exclude_hysteresis: Union[bool, List] = False, 1588 include_classes: List = None, 1589 include_threshold: Union[float, List] = 0.4, 1590 include_threshold_diff: Union[float, List] = 0.1, 1591 include_hysteresis: Union[bool, List] = False, 1592 data_path: str = None, 1593 file_paths: Set = None, 1594 parameters_update: Dict = None, 1595 mode: str = "all", 1596 force: bool = False, 1597 remove_saved_features: bool = False, 1598 cut_annotated: bool = False, 1599 background_threshold: float = None, 1600 ) -> None: 1601 """Create active learning and suggestion files. 1602 1603 Generate predictions with the error and suggestion model and use them to create 1604 suggestion files for the labeling interface. Those files will render as suggested labels 1605 at intervals with high pose estimation quality. Quality here is defined by probability of error 1606 (predicted by the error model) and visibility parameters. 
1607 1608 If `error_episode` or `exclude_classes` is not `None`, 1609 an active learning file will be created as well (with frames with high predicted probability of classes 1610 from `exclude_classes` and/or errors excluded from the active learning intervals). 1611 1612 In all three steps (predicting errors, suggesting labels and excluding them from active learning intervals) 1613 you can apply one of three methods. 1614 1615 - **Simple threshold** 1616 1617 Set the `hysteresis` parameter (e.g. `error_hysteresis`) to `False` and the `threshold` 1618 parameter to $\alpha$. 1619 In this case if the probability of a label is predicted to be higher than $\alpha$ the frame will 1620 be considered labeled. 1621 1622 - **Hysteresis threshold** 1623 1624 Set the `hysteresis` parameter (e.g. `error_hysteresis`) to `True`, the `threshold` 1625 parameter to $\alpha$ and the `threshold_diff` parameter to $\beta$. 1626 Now intervals will be marked with a label if the probability of that label for all frames is higher 1627 than $\alpha - \beta$ and at least for one frame in that interval it is higher than $\alpha$. 1628 1629 - **Max hysteresis threshold** 1630 1631 Set the `hysteresis` parameter (e.g. `error_hysteresis`) to `True`, the `threshold` 1632 parameter to $\alpha$ and the `threshold_diff` parameter to `None`. 1633 With this combination intervals are marked with a label if that label is more likely than any other 1634 for all frames in this interval and at for at least one of those frames its probability is higher than 1635 $\alpha$. 
1636 1637 Parameters 1638 ---------- 1639 suggestions_name : str 1640 the name of the suggestions 1641 error_episode : str, optional 1642 the name of the episode where the error model should be loaded from 1643 error_load_epoch : int, optional 1644 the epoch the error model should be loaded from 1645 error_class : str, optional 1646 the name of the error class (in `error_episode`) 1647 suggestions_prediction : str, optional 1648 the name of the predictions that should be used for the suggestion model 1649 suggestion_episodes : list, optional 1650 the names of the episodes where the suggestion models should be loaded from 1651 suggestion_load_epoch : int, optional 1652 the epoch the suggestion model should be loaded from 1653 suggestion_classes : list, optional 1654 a list of string names of the classes that should be suggested (in `suggestion_episode`) 1655 error_threshold : float, default 0.5 1656 the hard threshold for error prediction 1657 error_threshold_diff : float, default 0.1 1658 the difference between soft and hard thresholds for error prediction (in case hysteresis is used) 1659 error_hysteresis : bool, default False 1660 if True, hysteresis is used for error prediction 1661 suggestion_threshold : float | list, default 0.5 1662 the hard threshold for class prediction (use a list to set different rules for different classes) 1663 suggestion_threshold_diff : float | list, default 0.1 1664 the difference between soft and hard thresholds for class prediction (in case hysteresis is used; 1665 use a list to set different rules for different classes) 1666 suggestion_hysteresis : bool | list, default True 1667 if True, hysteresis is used for class prediction (use a list to set different rules for different classes) 1668 min_frames_suggestion : int, default 10 1669 only actions longer than this number of frames will be suggested 1670 min_frames_al : int, default 30 1671 only active learning intervals longer than this number of frames will be suggested 1672 
visibility_min_score : float, default 0 1673 the minimum visibility score for visibility filtering 1674 visibility_min_frac : float, default 0.7 1675 the minimum fraction of visible frames for visibility filtering 1676 augment_n : int, default 10 1677 the number of augmentations to average the predictions over 1678 exclude_classes : list, optional 1679 a list of string names of classes that should be excluded from the active learning intervals 1680 exclude_threshold : float | list, default 0.6 1681 the hard threshold for excluded class prediction (use a list to set different rules for different classes) 1682 exclude_threshold_diff : float | list, default 0.1 1683 the difference between soft and hard thresholds for excluded class prediction (in case hysteresis is used) 1684 exclude_hysteresis : bool | list, default False 1685 if True, hysteresis is used for excluded class prediction (use a list to set different rules for different classes) 1686 include_classes : list, optional 1687 a list of string names of classes that should be included into the active learning intervals 1688 include_threshold : float | list, default 0.6 1689 the hard threshold for included class prediction (use a list to set different rules for different classes) 1690 include_threshold_diff : float | list, default 0.1 1691 the difference between soft and hard thresholds for included class prediction (in case hysteresis is used) 1692 include_hysteresis : bool | list, default False 1693 if True, hysteresis is used for included class prediction (use a list to set different rules for different classes) 1694 data_path : str, optional 1695 the data path to run the prediction for 1696 file_paths : set, optional 1697 a set of string file paths (data with all prefixes + feature files, in any order) to run the prediction 1698 for 1699 parameters_update : dict, optional 1700 the parameters update dictionary 1701 mode : {'all', 'test', 'val', 'train'} 1702 the subset of the data to make the prediction for 
(forced to 'all' if data_path is not None) 1703 force : bool, default False 1704 if `True` and an episode with name `episode_name` already exists, it will be overwritten (use with caution!) 1705 remove_saved_features : bool, default False 1706 if `True`, the dataset will be deleted. 1707 cut_annotated : bool, default False 1708 if `True`, annotated frames will be cut from the suggestions 1709 background_threshold : float, default 0.5 1710 the threshold for background prediction 1711 1712 """ 1713 self._check_suggestions_validity(suggestions_name, force=force) 1714 if any([x is None for x in suggestion_episodes]): 1715 suggestion_episodes = None 1716 if error_episode is None and ( 1717 suggestion_episodes is None and suggestions_prediction is None 1718 ): 1719 raise ValueError( 1720 "Both error_episode and suggestion_episode parameters cannot be None at the same time" 1721 ) 1722 print(f"SUGGESTION {suggestions_name}") 1723 task = None 1724 if suggestion_classes is None: 1725 suggestion_classes = [] 1726 if exclude_classes is None: 1727 exclude_classes = [] 1728 if include_classes is None: 1729 include_classes = [] 1730 if isinstance(suggestion_threshold, list): 1731 if len(suggestion_threshold) != len(suggestion_classes): 1732 raise ValueError( 1733 "The suggestion_threshold parameter has to be either a float value or a list of " 1734 f"float values of the same length as suggestion_classes (got a list of length " 1735 f"{len(suggestion_threshold)} for {len(suggestion_classes)} classes)" 1736 ) 1737 else: 1738 suggestion_threshold = [suggestion_threshold for _ in suggestion_classes] 1739 if isinstance(suggestion_threshold_diff, list): 1740 if len(suggestion_threshold_diff) != len(suggestion_classes): 1741 raise ValueError( 1742 "The suggestion_threshold_diff parameter has to be either a float value or a list of " 1743 f"float values of the same length as suggestion_classes (got a list of length " 1744 f"{len(suggestion_threshold)} for {len(suggestion_classes)} 
classes)" 1745 ) 1746 else: 1747 suggestion_threshold_diff = [ 1748 suggestion_threshold_diff for _ in suggestion_classes 1749 ] 1750 if isinstance(suggestion_hysteresis, list): 1751 if len(suggestion_hysteresis) != len(suggestion_classes): 1752 raise ValueError( 1753 "The suggestion_threshold_diff parameter has to be either a float value or a list of " 1754 f"float values of the same length as suggestion_classes (got a list of length " 1755 f"{len(suggestion_hysteresis)} for {len(suggestion_classes)} classes)" 1756 ) 1757 else: 1758 suggestion_hysteresis = [suggestion_hysteresis for _ in suggestion_classes] 1759 if isinstance(exclude_threshold, list): 1760 if len(exclude_threshold) != len(exclude_classes): 1761 raise ValueError( 1762 "The exclude_threshold parameter has to be either a float value or a list of " 1763 f"float values of the same length as exclude_classes (got a list of length " 1764 f"{len(exclude_threshold)} for {len(exclude_classes)} classes)" 1765 ) 1766 else: 1767 exclude_threshold = [exclude_threshold for _ in exclude_classes] 1768 if isinstance(exclude_threshold_diff, list): 1769 if len(exclude_threshold_diff) != len(exclude_classes): 1770 raise ValueError( 1771 "The exclude_threshold_diff parameter has to be either a float value or a list of " 1772 f"float values of the same length as exclude_classes (got a list of length " 1773 f"{len(exclude_threshold_diff)} for {len(exclude_classes)} classes)" 1774 ) 1775 else: 1776 exclude_threshold_diff = [exclude_threshold_diff for _ in exclude_classes] 1777 if isinstance(exclude_hysteresis, list): 1778 if len(exclude_hysteresis) != len(exclude_classes): 1779 raise ValueError( 1780 "The suggestion_threshold_diff parameter has to be either a float value or a list of " 1781 f"float values of the same length as suggestion_classes (got a list of length " 1782 f"{len(exclude_hysteresis)} for {len(exclude_classes)} classes)" 1783 ) 1784 else: 1785 exclude_hysteresis = [exclude_hysteresis for _ in 
exclude_classes] 1786 if isinstance(include_threshold, list): 1787 if len(include_threshold) != len(include_classes): 1788 raise ValueError( 1789 "The exclude_threshold parameter has to be either a float value or a list of " 1790 f"float values of the same length as exclude_classes (got a list of length " 1791 f"{len(include_threshold)} for {len(include_classes)} classes)" 1792 ) 1793 else: 1794 include_threshold = [include_threshold for _ in include_classes] 1795 if isinstance(include_threshold_diff, list): 1796 if len(include_threshold_diff) != len(include_classes): 1797 raise ValueError( 1798 "The exclude_threshold_diff parameter has to be either a float value or a list of " 1799 f"float values of the same length as exclude_classes (got a list of length " 1800 f"{len(include_threshold_diff)} for {len(include_classes)} classes)" 1801 ) 1802 else: 1803 include_threshold_diff = [include_threshold_diff for _ in include_classes] 1804 if isinstance(include_hysteresis, list): 1805 if len(include_hysteresis) != len(include_classes): 1806 raise ValueError( 1807 "The suggestion_threshold_diff parameter has to be either a float value or a list of " 1808 f"float values of the same length as suggestion_classes (got a list of length " 1809 f"{len(include_hysteresis)} for {len(include_classes)} classes)" 1810 ) 1811 else: 1812 include_hysteresis = [include_hysteresis for _ in include_classes] 1813 if (suggestion_episodes is None and suggestions_prediction is None) and len( 1814 exclude_classes 1815 ) > 0: 1816 raise ValueError( 1817 "In order to exclude classes from the active learning intervals you need to set the " 1818 "suggestion_episode parameter" 1819 ) 1820 1821 task = None 1822 if error_episode is not None: 1823 task, parameters, mode = self._make_task_prediction( 1824 prediction_name=suggestions_name, 1825 load_episode=error_episode, 1826 parameters_update=parameters_update, 1827 load_epoch=error_load_epoch, 1828 data_path=data_path, 1829 mode=mode, 1830 
file_paths=file_paths, 1831 task=task, 1832 ) 1833 predicted_error = task.predict( 1834 data=mode, 1835 raw_output=True, 1836 apply_primary_function=True, 1837 augment_n=augment_n, 1838 ) 1839 else: 1840 predicted_error = None 1841 1842 if suggestion_episodes is not None: 1843 ( 1844 task, 1845 parameters, 1846 mode, 1847 predicted_classes, 1848 inference_time, 1849 behavior_dict, 1850 ) = self._make_prediction( 1851 prediction_name=suggestions_name, 1852 episode_names=suggestion_episodes, 1853 load_epochs=suggestion_load_epoch, 1854 parameters_update=parameters_update, 1855 data_path=data_path, 1856 file_paths=file_paths, 1857 mode=mode, 1858 task=task, 1859 ) 1860 elif suggestions_prediction is not None: 1861 with open( 1862 os.path.join( 1863 self.project_path, 1864 "results", 1865 "predictions", 1866 f"{suggestions_prediction}.pickle", 1867 ), 1868 "rb", 1869 ) as f: 1870 predicted_classes = pickle.load(f) 1871 if parameters_update is None: 1872 parameters_update = {} 1873 parameters_update = self._update( 1874 self._predictions().load_parameters(suggestions_prediction), 1875 parameters_update, 1876 ) 1877 parameters_update.pop("model") 1878 if suggestion_episodes is None: 1879 suggestion_episodes = [ 1880 os.path.basename( 1881 os.path.dirname( 1882 parameters_update["training"]["checkpoint_path"] 1883 ) 1884 ) 1885 ] 1886 task, parameters, mode = self._make_task_prediction( 1887 "_", 1888 load_episode=None, 1889 parameters_update=parameters_update, 1890 data_path=data_path, 1891 file_paths=file_paths, 1892 mode=mode, 1893 ) 1894 else: 1895 predicted_classes = None 1896 1897 if len(suggestion_classes) > 0 and predicted_classes is not None: 1898 suggestions = self._make_suggestions( 1899 task, 1900 predicted_error, 1901 predicted_classes, 1902 suggestion_threshold, 1903 suggestion_threshold_diff, 1904 suggestion_hysteresis, 1905 suggestion_episodes, 1906 suggestion_classes, 1907 error_threshold, 1908 min_frames_suggestion, 1909 min_frames_al, 1910 
visibility_min_score, 1911 visibility_min_frac, 1912 cut_annotated=cut_annotated, 1913 ) 1914 videos = list(suggestions.keys()) 1915 for v_id in videos: 1916 times_dict = defaultdict(lambda: defaultdict(lambda: [])) 1917 clips = set() 1918 for c in suggestions[v_id]: 1919 for start, end, ind in suggestions[v_id][c]: 1920 times_dict[ind][c].append([start, end, 2]) 1921 clips.add(ind) 1922 clips = list(clips) 1923 times_dict = dict(times_dict) 1924 times = [ 1925 [times_dict[ind][c] for c in suggestion_classes] for ind in clips 1926 ] 1927 save_path = self._suggestion_path(v_id, suggestions_name) 1928 Path(os.path.dirname(save_path)).mkdir(parents=True, exist_ok=True) 1929 with open(save_path, "wb") as f: 1930 pickle.dump((None, suggestion_classes, clips, times), f) 1931 1932 if ( 1933 error_episode is not None 1934 or len(exclude_classes) > 0 1935 or len(include_classes) > 0 1936 ): 1937 al_points = self._make_al_points( 1938 task, 1939 predicted_error, 1940 predicted_classes, 1941 exclude_classes, 1942 exclude_threshold, 1943 exclude_threshold_diff, 1944 exclude_hysteresis, 1945 include_classes, 1946 include_threshold, 1947 include_threshold_diff, 1948 include_hysteresis, 1949 error_episode, 1950 error_class, 1951 suggestion_episodes, 1952 error_threshold, 1953 error_threshold_diff, 1954 error_hysteresis, 1955 min_frames_al, 1956 visibility_min_score, 1957 visibility_min_frac, 1958 ) 1959 else: 1960 al_points = self._make_al_points_from_suggestions( 1961 suggestions_name, 1962 task, 1963 predicted_classes, 1964 background_threshold, 1965 visibility_min_score, 1966 visibility_min_frac, 1967 num_behaviors=len(task.behaviors_dict()), 1968 ) 1969 save_path = self._al_points_path(suggestions_name) 1970 Path(os.path.dirname(save_path)).mkdir(parents=True, exist_ok=True) 1971 with open(save_path, "wb") as f: 1972 pickle.dump(al_points, f) 1973 1974 meta_parameters = { 1975 "error_episode": error_episode, 1976 "error_load_epoch": error_load_epoch, 1977 "error_class": 
def _generate_similarity_score(
    self,
    prediction_name: str,
    target_video_id: str,
    target_clip: str,
    target_start: int,
    target_end: int,
) -> Dict:
    """Score every frame of a saved prediction by its distance to a target interval.

    Parameters
    ----------
    prediction_name : str
        the name of the prediction; it is read from
        `<project_path>/results/predictions/<prediction_name>.pickle`
    target_video_id : str
        the video id of the target interval
    target_clip : str
        the clip id of the target interval
    target_start : int
        the first frame of the target interval
    target_end : int
        the end frame of the target interval (exclusive)

    Returns
    -------
    score_dict : dict
        a nested dictionary (video id -> clip id -> 1D tensor); each element is the
        smallest distance between the prediction at that frame and the prediction at
        any frame of the target interval (lower means more similar)

    """
    prediction_file = os.path.join(
        self.project_path,
        "results",
        "predictions",
        f"{prediction_name}.pickle",
    )
    # NOTE(review): pickle is only safe for files written by this project itself.
    with open(prediction_file, "rb") as f:
        prediction = pickle.load(f)
    # assumes every entry is a 2D tensor with frames on the last axis — TODO confirm
    target = prediction[target_video_id][target_clip][:, target_start:target_end]
    score_dict = defaultdict(dict)
    for video_id, clips in prediction.items():
        for clip_id, clip_prediction in clips.items():
            # rows: target frames, columns: frames of the current clip
            distances = torch.cdist(target.T, clip_prediction.T)
            # `min(dim)` returns a (values, indices) pair; only the values are scores
            score_dict[video_id][clip_id] = distances.min(0).values
    return score_dict


def _suggest_intervals_from_dict(self, score_dict, min_length, n_intervals) -> Dict:
    """Suggest intervals from a score dictionary.

    The threshold starts at a quarter of the mean score and is relaxed over at most
    ten rounds until at least `n_intervals` intervals below it are found.

    Parameters
    ----------
    score_dict : dict
        a nested dictionary (video id -> clip id -> score tensor)
    min_length : int
        minimum length of intervals to suggest
    n_intervals : int
        number of intervals to suggest

    Returns
    -------
    intervals : dict
        a plain dictionary mapping video ids to lists of `[start, end, clip_id]`

    """
    clip_means = [
        value.mean()
        for video_dict in score_dict.values()
        for value in video_dict.values()
    ]
    if not clip_means:
        # nothing to threshold; avoid dividing by zero below
        warnings.warn(
            f"Could not get {n_intervals} intervals from the data, saving the result with 0 intervals"
        )
        return {}
    mean_value = sum(clip_means) / len(clip_means)
    alpha = 1.75
    interval_address = {}
    interval_value = {}
    for it in range(10):
        # every round starts from scratch with a more permissive threshold
        interval_address = {}
        interval_value = {}
        threshold = (2 - alpha * (0.9**it)) * mean_value
        for video_id, video_dict in score_dict.items():
            for clip_id, value in video_dict.items():
                res_indices_start, res_indices_end = apply_threshold(
                    value,
                    threshold=threshold,
                    low=True,
                    error_mask=None,
                    min_frames=min_length,
                    smooth_interval=0,
                )
                for start, end in zip(res_indices_start, res_indices_end):
                    interval_id = len(interval_address)
                    interval_address[interval_id] = [video_id, clip_id, start, end]
                    interval_value[interval_id] = value[start:end].mean()
        if len(interval_address) >= n_intervals:
            break
    if len(interval_address) < n_intervals:
        warnings.warn(
            f"Could not get {n_intervals} intervals from the data, saving the result with {len(interval_address)} intervals"
        )
    # NOTE(review): candidates are ranked by descending mean score, i.e. the least
    # similar of the selected intervals come first — confirm that this is intended.
    sorted_intervals = sorted(
        interval_value.items(), key=lambda x: x[1], reverse=True
    )
    output = defaultdict(list)
    for interval_id, _ in sorted_intervals[:n_intervals]:
        video_id, clip_id, start, end = interval_address[interval_id]
        output[video_id].append([start, end, clip_id])
    # a plain dict can be pickled by the callers (a lambda default factory cannot)
    return dict(output)


def suggest_intervals_with_similarity(
    self,
    suggestions_name: str,
    prediction_name: str,
    target_video_id: str,
    target_clip: str,
    target_start: int,
    target_end: int,
    min_length: int = 60,
    n_intervals: int = 5,
    force: bool = False,
) -> None:
    """Suggest intervals based on similarity to a target interval.

    Parameters
    ----------
    suggestions_name : str
        Name of the suggestion.
    prediction_name : str
        Name of the prediction to use.
    target_video_id : str
        Video id of the target interval.
    target_clip : str
        Clip id of the target interval.
    target_start : int
        Start frame of the target interval.
    target_end : int
        End frame of the target interval.
    min_length : int, default 60
        Minimum length of the suggested intervals.
    n_intervals : int, default 5
        Number of suggested intervals.
    force : bool, default False
        If True, the suggestion is overwritten if it already exists.

    """
    self._check_suggestions_validity(suggestions_name, force=force)
    print(f"SUGGESTION {suggestions_name}")
    score_dict = self._generate_similarity_score(
        prediction_name, target_video_id, target_clip, target_start, target_end
    )
    intervals = self._suggest_intervals_from_dict(
        score_dict, min_length, n_intervals
    )
    suggestions_path = os.path.join(
        self.project_path,
        "results",
        "suggestions",
        suggestions_name,
    )
    # also creates missing parent folders and tolerates an existing one
    os.makedirs(suggestions_path, exist_ok=True)
    with open(
        os.path.join(suggestions_path, f"{suggestions_name}_al_points.pickle"), "wb"
    ) as f:
        # convert to a plain dict: a defaultdict with a lambda factory is not picklable
        pickle.dump(dict(intervals), f)
    meta_parameters = {
        "prediction_name": prediction_name,
        "min_frames_suggestion": min_length,
        "n_intervals": n_intervals,
        "target_clip": target_clip,
        "target_start": target_start,
        "target_end": target_end,
    }
    self._save_suggestions(suggestions_name, {}, meta_parameters)
    print("\n")


def suggest_intervals_with_uncertainty(
    self,
    suggestions_name: str,
    episode_names: List,
    load_epochs: List = None,
    classes: List = None,
    n_frames: int = 10000,
    method: str = "least_confidence",
    min_length: int = 60,
    augment_n: int = 0,
    data_path: str = None,
    file_paths: Set = None,
    parameters_update: Dict = None,
    mode: str = "all",
    force: bool = False,
    remove_saved_features: bool = False,
) -> None:
    """Generate an active learning file based on model uncertainty.

    If you provide several episode names, the predicted probabilities will be averaged.

    Parameters
    ----------
    suggestions_name : str
        the name of the suggestion
    episode_names : list
        a list of string episode names to load the models from
    load_epochs : list, optional
        a list of epoch indices to load the models from (if `None`, the last ones will be used)
    classes : list, optional
        a list of classes to look at (by default all)
    n_frames : int, default 10000
        the threshold total number of frames in the suggested intervals (in the end result it will most likely
        be slightly larger; it will only be smaller if the algorithm fails to find enough intervals
        with the set parameters)
    method : {"least_confidence", "entropy"}
        the method used to calculate the scores from the probability predictions (`"least_confidence"`: `1 - p_i` if
        `p_i > 0.5` or `p_i` if `p_i < 0.5`; `"entropy"`: `- p_i * log(p_i) - (1 - p_i) * log(1 - p_i)`)
    min_length : int, default 60
        the minimum number of frames in one interval
    augment_n : int, default 0
        the number of augmentations to average the predictions over
    data_path : str, optional
        the path to a data folder (by default, the project data is used)
    file_paths : set, optional
        a list of file paths (by default, the project data is used)
    parameters_update : dict, optional
        a dictionary of parameter updates
    mode : {"test", "val", "train", "all"}, default "all"
        the subset of the data to make the prediction for (passed on to the prediction step)
    force : bool, default False
        if `True`, existing suggestions with the same name will be overwritten
    remove_saved_features : bool, default False
        if `True`, the dataset will be deleted after the computation

    """
    self._check_suggestions_validity(suggestions_name, force=force)
    print(f"SUGGESTION {suggestions_name}")
    task, parameters, mode, predicted, inference_time, behavior_dict = (
        self._make_prediction(
            suggestions_name,
            episode_names,
            load_epochs,
            parameters_update,
            data_path=data_path,
            file_paths=file_paths,
            mode=mode,
            augment_n=augment_n,
            evaluate=False,
        )
    )
    if classes is None:
        # materialize the view so that it can be stored in the metadata,
        # like in suggest_intervals_with_bald
        classes = list(behavior_dict.values())
    # the behaviors dictionary is taken from the first run of the first episode
    episode = self._episodes().get_runs(episode_names[0])[0]
    score_tensors = task.generate_uncertainty_score(
        classes,
        augment_n,
        method,
        predicted,
        self._episode(episode).get_behaviors_dict(),
    )
    intervals = self._suggest_intervals(
        task.dataset(mode), score_tensors, n_frames, min_length
    )
    for k, v in intervals.items():
        total_length = sum(x[1] - x[0] for x in v)
        print(f"{k}: {len(v)} ({total_length})")
    if remove_saved_features:
        self._remove_stores(parameters)
    suggestions_path = os.path.join(
        self.project_path,
        "results",
        "suggestions",
        suggestions_name,
    )
    # also creates missing parent folders and tolerates an existing one
    os.makedirs(suggestions_path, exist_ok=True)
    with open(
        os.path.join(suggestions_path, f"{suggestions_name}_al_points.pickle"), "wb"
    ) as f:
        pickle.dump(intervals, f)
    meta_parameters = {
        "suggestion_episode": episode_names,
        "suggestion_load_epoch": load_epochs,
        "suggestion_classes": classes,
        "min_frames_suggestion": min_length,
        "augment_n": augment_n,
        "method": method,
        "num_frames": n_frames,
    }
    self._save_suggestions(suggestions_name, {}, meta_parameters)
    print("\n")


def suggest_intervals_with_bald(
    self,
    suggestions_name: str,
    episode_name: str,
    load_epoch: int = None,
    classes: List = None,
    n_frames: int = 10000,
    num_models: int = 10,
    kernel_size: int = 11,
    min_length: int = 60,
    augment_n: int = 0,
    data_path: str = None,
    file_paths: Set = None,
    parameters_update: Dict = None,
    mode: str = "all",
    force: bool = False,
    remove_saved_features: bool = False,
) -> None:
    """Generate an active learning file based on Bayesian Active Learning by Disagreement.

    Parameters
    ----------
    suggestions_name : str
        the name of the suggestion
    episode_name : str
        the name of the episode to load the model from
    load_epoch : int, optional
        the index of the epoch to load the model from (if `None`, the last one will be used)
    classes : list, optional
        a list of classes to look at (by default all)
    n_frames : int, default 10000
        the threshold total number of frames in the suggested intervals (in the end result it will most likely
        be slightly larger; it will only be smaller if the algorithm fails to find enough intervals
        with the set parameters)
    num_models : int, default 10
        the number of dropout masks to apply
    kernel_size : int, default 11
        the size of the smoothing kernel applied to the discrete results
    min_length : int, default 60
        the minimum number of frames in one interval
    augment_n : int, default 0
        the number of augmentations to average the predictions over
    data_path : str, optional
        the path to a data folder (by default, the project data is used)
    file_paths : set, optional
        a list of file paths (by default, the project data is used)
    parameters_update : dict, optional
        a dictionary of parameter updates
    mode : {"test", "val", "train", "all"}, default "all"
        the subset of the data to make the prediction for (passed on to the prediction step)
    force : bool, default False
        if `True`, existing suggestions with the same name will be overwritten
    remove_saved_features : bool, default False
        if `True`, the dataset will be deleted after the computation

    """
    self._check_suggestions_validity(suggestions_name, force=force)
    print(f"SUGGESTION {suggestions_name}")
    task, parameters, mode = self._make_task_prediction(
        suggestions_name,
        episode_name,
        parameters_update,
        load_epoch,
        data_path=data_path,
        file_paths=file_paths,
        mode=mode,
    )
    if classes is None:
        classes = list(task.behaviors_dict().values())
    score_tensors = task.generate_bald_score(
        classes, augment_n, num_models, kernel_size
    )
    intervals = self._suggest_intervals(
        task.dataset(mode), score_tensors, n_frames, min_length
    )
    if remove_saved_features:
        self._remove_stores(parameters)
    suggestions_path = os.path.join(
        self.project_path,
        "results",
        "suggestions",
        suggestions_name,
    )
    # also creates missing parent folders and tolerates an existing one
    os.makedirs(suggestions_path, exist_ok=True)
    with open(
        os.path.join(suggestions_path, f"{suggestions_name}_al_points.pickle"), "wb"
    ) as f:
        pickle.dump(intervals, f)
    meta_parameters = {
        "suggestion_episode": episode_name,
        "suggestion_load_epoch": load_epoch,
        "suggestion_classes": classes,
        "min_frames_suggestion": min_length,
        "augment_n": augment_n,
        "method": f"BALD:{num_models}",
        "num_frames": n_frames,
    }
    self._save_suggestions(suggestions_name, {}, meta_parameters)
    print("\n")
def list_episodes(
    self,
    episode_names: List = None,
    value_filter: str = "",
    display_parameters: List = None,
    print_results: bool = True,
) -> pd.DataFrame:
    """Return a filtered table of training episode metadata.

    Parameters
    ----------
    episode_names : list
        a list of strings of episode names
    value_filter : str
        a string of filters to apply; of this general structure:
        'group_name1/par_name1::(</>/<=/>=/=)value1,group_name2/par_name2::(</>/<=/>=/=)value2', e.g.
        'data/overlap::=50,results/recall::>0.5,data/feature_extraction::=kinematic,meta/training_time::>=00:00:10'
    display_parameters : list
        list of parameters to display (e.g. ['data/overlap', 'results/recall'])
    print_results : bool, default True
        if True, the result will be printed to standard output

    Returns
    -------
    pd.DataFrame
        the filtered dataframe

    """
    table = self._episodes().list_episodes(
        episode_names, value_filter, display_parameters
    )
    if not print_results:
        return table
    for item in ("TRAINING EPISODES", table, "\n"):
        print(item)
    return table


def list_predictions(
    self,
    episode_names: List = None,
    value_filter: str = "",
    display_parameters: List = None,
    print_results: bool = True,
) -> pd.DataFrame:
    """Return a filtered table of prediction metadata.

    Parameters
    ----------
    episode_names : list
        a list of strings of episode names
    value_filter : str
        a string of filters to apply; of this general structure:
        'group_name1/par_name1:(<>=)value1,group_name2/par_name2:(<>=)value2', e.g.
        'data/overlap:=50,results/recall:>0.5,data/feature_extraction:=kinematic'
    display_parameters : list
        list of parameters to display (e.g. ['data/overlap', 'results/recall'])
    print_results : bool, default True
        if True, the result will be printed to standard output

    Returns
    -------
    pd.DataFrame
        the filtered dataframe

    """
    table = self._predictions().list_episodes(
        episode_names, value_filter, display_parameters
    )
    if not print_results:
        return table
    for item in ("PREDICTIONS", table, "\n"):
        print(item)
    return table


def list_suggestions(
    self,
    suggestions_names: List = None,
    value_filter: str = "",
    display_parameters: List = None,
    print_results: bool = True,
) -> pd.DataFrame:
    """Return a filtered table of suggestion metadata.

    Parameters
    ----------
    suggestions_names : list
        a list of strings of suggestion names
    value_filter : str
        a string of filters to apply; of this general structure:
        'group_name1/par_name1:(<>=)value1,group_name2/par_name2:(<>=)value2', e.g.
        'data/overlap:=50,results/recall:>0.5,data/feature_extraction:=kinematic'
    display_parameters : list
        list of parameters to display (e.g. ['data/overlap', 'results/recall'])
    print_results : bool, default True
        if True, the result will be printed to standard output

    Returns
    -------
    pd.DataFrame
        the filtered dataframe

    """
    table = self._suggestions().list_episodes(
        suggestions_names, value_filter, display_parameters
    )
    if not print_results:
        return table
    for item in ("SUGGESTIONS", table, "\n"):
        print(item)
    return table


def list_searches(
    self,
    search_names: List = None,
    value_filter: str = "",
    display_parameters: List = None,
    print_results: bool = True,
) -> pd.DataFrame:
    """Return a filtered table of hyperparameter search metadata.

    Parameters
    ----------
    search_names : list
        a list of strings of search names
    value_filter : str
        a string of filters to apply; of this general structure:
        'group_name1/par_name1:(<>=)value1,group_name2/par_name2:(<>=)value2', e.g.
        'data/overlap:=50,results/recall:>0.5,data/feature_extraction:=kinematic'
    display_parameters : list
        list of parameters to display (e.g. ['data/overlap', 'results/recall'])
    print_results : bool, default True
        if True, the result will be printed to standard output

    Returns
    -------
    pd.DataFrame
        the filtered dataframe

    """
    table = self._searches().list_episodes(
        search_names, value_filter, display_parameters
    )
    if not print_results:
        return table
    for item in ("SEARCHES", table, "\n"):
        print(item)
    return table


def get_best_parameters(
    self,
    search_name: str,
    round_to_binary: List = None,
):
    """Return the best parameters of a search, with the model name filled in.

    Parameters
    ----------
    search_name : str
        the name of the search
    round_to_binary : list, default None
        a list of parameters to round to binary values

    Returns
    -------
    best_params : dict
        the best parameters of the search, updated with the model name under
        `general/model_name`

    """
    searches = self._searches()
    best_params, model_name = searches.get_best_params(
        search_name, round_to_binary=round_to_binary
    )
    model_update = {"general": {"model_name": model_name}}
    return self._update(best_params, model_update)


def list_best_parameters(
    self, search_name: str, print_results: bool = True
) -> Dict:
    """Return the raw dictionary of the best parameters found by a search.

    Parameters
    ----------
    search_name : str
        the name of the search
    print_results : bool, default True
        if True, the result will be printed to standard output

    Returns
    -------
    best_params : dict
        a dictionary of the best parameters where the keys are in '{group}/{name}' format

    """
    raw_params = self._searches().get_best_params_raw(search_name)
    if not print_results:
        return raw_params
    print(f"SEARCH RESULTS {search_name}")
    for key in raw_params:
        print(f"{key}: {raw_params[key]}")
    print("\n")
    return raw_params
['episode1', ['episode2', 'episode3']] to plot episode2 and episode3 as one experiment) 2603 metrics : list 2604 a list of metric to plot 2605 modes : list, optional 2606 a list of modes to plot ('train' and/or 'val'; `['val']` by default) 2607 title : str, optional 2608 title for the plot 2609 episode_labels : list, optional 2610 a list of strings used to label the curves (has to be the same length as episode_names) 2611 save_path : str, optional 2612 the path to save the resulting plot 2613 add_hlines : list, optional 2614 a list of float values (or (value, label) tuples) to mark with horizontal lines 2615 epoch_limits : list, optional 2616 a list of (min, max) tuples to set the x-axis limits for each episode 2617 colors: list, optional 2618 a list of matplotlib colors 2619 add_highpoint_hlines : bool, default False 2620 if `True`, horizontal lines will be added at the highest value of each episode 2621 """ 2622 2623 if isinstance(metrics, str): 2624 metrics = [metrics] 2625 if isinstance(modes, str): 2626 modes = [modes] 2627 2628 if font_size is not None: 2629 font = {"size": font_size} 2630 rc("font", **font) 2631 if modes is None: 2632 modes = ["val"] 2633 if add_hlines is None: 2634 add_hlines = [] 2635 logs = [] 2636 epochs = [] 2637 labels = [] 2638 if episode_labels is not None: 2639 assert len(episode_labels) == len(episode_names) 2640 for name_i, name in enumerate(episode_names): 2641 log_params = product(metrics, modes) 2642 for metric, mode in log_params: 2643 if episode_labels is not None: 2644 label = episode_labels[name_i] 2645 else: 2646 label = deepcopy(name) 2647 if len(modes) != 1: 2648 label += f"_{mode}" 2649 if len(metrics) != 1: 2650 label += f"_{metric}" 2651 labels.append(label) 2652 if isinstance(name, Iterable) and not isinstance(name, str): 2653 epoch_list = defaultdict(lambda: []) 2654 multi_logs = defaultdict(lambda: []) 2655 for i, n in enumerate(name): 2656 runs = self._episodes().get_runs(n) 2657 if len(runs) > 1: 2658 for run in 
runs: 2659 if "::" in run: 2660 index = run.split("::")[-1] 2661 else: 2662 index = run.split("#")[-1] 2663 if multi_logs[index] == []: 2664 if multi_logs["null"] is None: 2665 raise RuntimeError( 2666 "The run indices are not consistent across episodes!" 2667 ) 2668 else: 2669 multi_logs[index] += multi_logs["null"] 2670 multi_logs[index] += list( 2671 self._episode(run).get_metric_log(mode, metric) 2672 ) 2673 start = ( 2674 0 2675 if len(epoch_list[index]) == 0 2676 else epoch_list[index][-1] 2677 ) 2678 epoch_list[index] += [ 2679 x + start 2680 for x in self._episode(run).get_epoch_list(mode) 2681 ] 2682 multi_logs["null"] = None 2683 else: 2684 if len(multi_logs.keys()) > 1: 2685 raise RuntimeError( 2686 "Cannot plot a single-run episode after a multi-run episode!" 2687 ) 2688 multi_logs["null"] += list( 2689 self._episode(n).get_metric_log(mode, metric) 2690 ) 2691 start = ( 2692 0 2693 if len(epoch_list["null"]) == 0 2694 else epoch_list["null"][-1] 2695 ) 2696 epoch_list["null"] += [ 2697 x + start for x in self._episode(n).get_epoch_list(mode) 2698 ] 2699 if len(multi_logs.keys()) == 1: 2700 log = multi_logs["null"] 2701 epochs.append(epoch_list["null"]) 2702 else: 2703 log = tuple([v for k, v in multi_logs.items() if k != "null"]) 2704 epochs.append( 2705 tuple([v for k, v in epoch_list.items() if k != "null"]) 2706 ) 2707 else: 2708 runs = self._episodes().get_runs(name) 2709 if len(runs) > 1: 2710 log = [] 2711 for run in runs: 2712 tracked_metrics = self._episode(run).get_metrics() 2713 if metric in tracked_metrics: 2714 log.append( 2715 list( 2716 self._episode(run).get_metric_log(mode, metric) 2717 ) 2718 ) 2719 else: 2720 relevant = [] 2721 for m in tracked_metrics: 2722 m_split = m.split("_") 2723 if ( 2724 "_".join(m_split[:-1]) == metric 2725 and m_split[-1].isnumeric() 2726 ): 2727 relevant.append(m) 2728 if len(relevant) == 0: 2729 raise ValueError( 2730 f"The {metric} metric was not tracked at {run}" 2731 ) 2732 arr = 0 2733 for m in 
relevant: 2734 arr += self._episode(run).get_metric_log(mode, m) 2735 arr /= len(relevant) 2736 log.append(list(arr)) 2737 log = tuple(log) 2738 epochs.append( 2739 tuple( 2740 [ 2741 self._episode(run).get_epoch_list(mode) 2742 for run in runs 2743 ] 2744 ) 2745 ) 2746 else: 2747 tracked_metrics = self._episode(name).get_metrics() 2748 if metric in tracked_metrics: 2749 log = list(self._episode(name).get_metric_log(mode, metric)) 2750 else: 2751 relevant = [] 2752 for m in tracked_metrics: 2753 m_split = m.split("_") 2754 if ( 2755 "_".join(m_split[:-1]) == metric 2756 and m_split[-1].isnumeric() 2757 ): 2758 relevant.append(m) 2759 if len(relevant) == 0: 2760 raise ValueError( 2761 f"The {metric} metric was not tracked at {name}" 2762 ) 2763 arr = 0 2764 for m in relevant: 2765 arr += self._episode(name).get_metric_log(mode, m) 2766 arr /= len(relevant) 2767 log = list(arr) 2768 epochs.append(self._episode(name).get_epoch_list(mode)) 2769 logs.append(log) 2770 # if episode_labels is not None: 2771 # print(f'{len(episode_labels)=}, {len(logs)=}') 2772 # if len(episode_labels) != len(logs): 2773 2774 # raise ValueError( 2775 # f"The length of episode_labels ({len(episode_labels)}) has to be equal to the length of " 2776 # f"curves ({len(logs)})!" 2777 # ) 2778 # else: 2779 # labels = episode_labels 2780 if colors is None: 2781 colors = cm.rainbow(np.linspace(0, 1, len(logs))) 2782 if len(colors) != len(logs): 2783 raise ValueError( 2784 "The length of colors has to be equal to the length of curves (metrics * modes * episode_names)!" 
2785 ) 2786 f, ax = plt.subplots() 2787 length = 0 2788 for log, label, color, epoch_list in zip(logs, labels, colors, epochs): 2789 if type(log) is list: 2790 if len(log) > length: 2791 length = len(log) 2792 ax.plot( 2793 epoch_list, 2794 log, 2795 label=label, 2796 color=color, 2797 ) 2798 if add_highpoint_hlines: 2799 ax.axhline(np.max(log), linestyle="dashed", color=color) 2800 else: 2801 for l, xx in zip(log, epoch_list): 2802 if len(l) > length: 2803 length = len(l) 2804 ax.plot( 2805 xx, 2806 l, 2807 color=color, 2808 alpha=0.2, 2809 ) 2810 if not all([len(x) == len(log[0]) for x in log]): 2811 warnings.warn( 2812 f"Got logs with unequal lengths in parallel runs for {label}" 2813 ) 2814 log = list(log) 2815 epoch_list = list(epoch_list) 2816 for i, x in enumerate(epoch_list): 2817 to_remove = [] 2818 for j, y in enumerate(x[1:]): 2819 if y <= x[j - 1]: 2820 y_ind = x.index(y) 2821 to_remove += list(range(y_ind, j)) 2822 epoch_list[i] = [ 2823 y for j, y in enumerate(x) if j not in to_remove 2824 ] 2825 log[i] = [y for j, y in enumerate(log[i]) if j not in to_remove] 2826 length = min([len(x) for x in log]) 2827 for i in range(len(log)): 2828 log[i] = log[i][:length] 2829 epoch_list[i] = epoch_list[i][:length] 2830 if not all([x == epoch_list[0] for x in epoch_list]): 2831 raise RuntimeError( 2832 f"Got different epoch indices in parallel runs for {label}" 2833 ) 2834 mean = np.array(log).mean(0) 2835 ax.plot( 2836 epoch_list[0], 2837 mean, 2838 label=label, 2839 color=color, 2840 linewidth=linewidth, 2841 ) 2842 if add_highpoint_hlines: 2843 ax.axhline(np.max(mean), linestyle="dashed", color=color) 2844 for x in add_hlines: 2845 label = None 2846 if isinstance(x, Iterable): 2847 x, label = x 2848 ax.axhline(x, label=label) 2849 ax.set_xlim((0, length)) 2850 2851 ax.legend() 2852 ax.set_xlabel("epochs") 2853 if len(metrics) == 1: 2854 ax.set_ylabel(metrics[0]) 2855 else: 2856 ax.set_ylabel("value") 2857 if title is None: 2858 if len(episode_names) == 1: 2859 
title = episode_names[0] 2860 elif len(metrics) == 1: 2861 title = metrics[0] 2862 if epoch_limits is not None: 2863 ax.set_xlim(epoch_limits) 2864 if title is not None: 2865 ax.set_title(title) 2866 if remove_box: 2867 ax.box(False) 2868 if return_ax: 2869 return ax 2870 if save_path is not None: 2871 plt.savefig(save_path) 2872 plt.show() 2873 2874 def update_parameters( 2875 self, 2876 parameters_update: Dict = None, 2877 load_search: str = None, 2878 load_parameters: List = None, 2879 round_to_binary: List = None, 2880 ) -> None: 2881 """Update the parameters in the project config files. 2882 2883 Parameters 2884 ---------- 2885 parameters_update : dict, optional 2886 a dictionary of parameter updates 2887 load_search : str, optional 2888 the name of hyperparameter search results to load to config 2889 load_parameters : list, optional 2890 a list of lists of string names of the parameters to load from the searches 2891 round_to_binary : list, optional 2892 a list of string names of the loaded parameters that should be rounded to the nearest power of two 2893 2894 """ 2895 keys = [ 2896 "general", 2897 "losses", 2898 "metrics", 2899 "ssl", 2900 "training", 2901 "data", 2902 ] 2903 parameters = self._read_parameters(catch_blanks=False) 2904 if parameters_update is not None: 2905 model_params = ( 2906 parameters_update.pop("model") if "model" in parameters_update else None 2907 ) 2908 feat_params = ( 2909 parameters_update.pop("features") 2910 if "features" in parameters_update 2911 else None 2912 ) 2913 aug_params = ( 2914 parameters_update.pop("augmentations") 2915 if "augmentations" in parameters_update 2916 else None 2917 ) 2918 2919 parameters = self._update(parameters, parameters_update) 2920 model_name = parameters["general"]["model_name"] 2921 parameters["model"] = self._open_yaml( 2922 os.path.join(self.project_path, "config", "model", f"{model_name}.yaml") 2923 ) 2924 if model_params is not None: 2925 parameters["model"] = 
def update_parameters(
    self,
    parameters_update: Dict = None,
    load_search: str = None,
    load_parameters: List = None,
    round_to_binary: List = None,
) -> None:
    """Update the parameters in the project config files.

    The current configuration is read from disk, the requested changes are
    merged in and every config file is written back: the top-level sections
    plus the model, feature extraction and augmentation files that are
    selected in the "general" section.

    Parameters
    ----------
    parameters_update : dict, optional
        a dictionary of parameter updates (it is not modified by this method)
    load_search : str, optional
        the name of hyperparameter search results to load to config
    load_parameters : list, optional
        a list of lists of string names of the parameters to load from the searches
    round_to_binary : list, optional
        a list of string names of the loaded parameters that should be rounded to the nearest power of two

    """
    keys = [
        "general",
        "losses",
        "metrics",
        "ssl",
        "training",
        "data",
    ]
    config_dir = os.path.join(self.project_path, "config")
    parameters = self._read_parameters(catch_blanks=False)
    if parameters_update is not None:
        # Work on a shallow copy: the nested sections are popped below and the
        # caller's dictionary must not lose those keys as a side effect.
        parameters_update = dict(parameters_update)
        model_params = parameters_update.pop("model", None)
        feat_params = parameters_update.pop("features", None)
        aug_params = parameters_update.pop("augmentations", None)

        parameters = self._update(parameters, parameters_update)
        # The update may have switched the model or the feature extractor, so
        # the dependent sections are reloaded from their own files before the
        # section-specific updates are applied.
        model_name = parameters["general"]["model_name"]
        parameters["model"] = self._open_yaml(
            os.path.join(config_dir, "model", f"{model_name}.yaml")
        )
        if model_params is not None:
            parameters["model"] = self._update(parameters["model"], model_params)
        feat_name = parameters["general"]["feature_extraction"]
        parameters["features"] = self._open_yaml(
            os.path.join(config_dir, "features", f"{feat_name}.yaml")
        )
        if feat_params is not None:
            parameters["features"] = self._update(
                parameters["features"], feat_params
            )
        aug_name = options.extractor_to_transformer[
            parameters["general"]["feature_extraction"]
        ]
        parameters["augmentations"] = self._open_yaml(
            os.path.join(config_dir, "augmentations", f"{aug_name}.yaml")
        )
        if aug_params is not None:
            parameters["augmentations"] = self._update(
                parameters["augmentations"], aug_params
            )
    if load_search is not None:
        parameters_update, model_name = self._searches().get_best_params(
            load_search, load_parameters, round_to_binary
        )
        parameters["general"]["model_name"] = model_name
        parameters["model"] = self._open_yaml(
            os.path.join(config_dir, "model", f"{model_name}.yaml")
        )
        parameters = self._update(parameters, parameters_update)
    for key in keys:
        with open(
            os.path.join(config_dir, f"{key}.yaml"),
            "w",
            encoding="utf-8",
        ) as f:
            YAML().dump(parameters[key], f)
    model_name = parameters["general"]["model_name"]
    model_path = os.path.join(config_dir, "model", f"{model_name}.yaml")
    with open(model_path, "w", encoding="utf-8") as f:
        YAML().dump(parameters["model"], f)
    features_name = parameters["general"]["feature_extraction"]
    features_path = os.path.join(config_dir, "features", f"{features_name}.yaml")
    with open(features_path, "w", encoding="utf-8") as f:
        YAML().dump(parameters["features"], f)
    aug_name = options.extractor_to_transformer[features_name]
    aug_path = os.path.join(config_dir, "augmentations", f"{aug_name}.yaml")
    with open(aug_path, "w", encoding="utf-8") as f:
        YAML().dump(parameters["augmentations"], f)
def get_summary(
    self,
    episode_names: list,
    method: str = "last",
    average: int = 1,
    metrics: List = None,
    return_values: bool = False,
) -> Dict:
    """Get a summary of episode statistics.

    If an episode has multiple runs, the statistics will be aggregated over all of them.

    Parameters
    ----------
    episode_names : list or str
        the names of the episodes (a single name can be passed as a string)
    method : ["best", "last", "epoch<N>"]
        the method for choosing the epochs: the highest values, the most recent values
        or the value logged for epoch N (e.g. "epoch10")
    average : int, default 1
        the number of epochs to average over (for each run)
    metrics : list, optional
        a list of metrics (by default the metrics tracked by the first run)
    return_values : bool, default False
        if `True`, the raw values used for the statistics are returned as well

    Returns
    -------
    statistics : dict
        a nested dictionary where first-level keys are metric names and second-level keys are 'mean' for the mean
        and 'std' for the standard deviation
    values : dict
        a dictionary with the list of collected values for each metric (only if `return_values` is `True`)

    Raises
    ------
    RuntimeError
        if an episode has no runs or a metric has no logged value for a run
    ValueError
        if `method` is not recognized

    """
    # A bare string would otherwise be iterated character by character.
    if isinstance(episode_names, str):
        episode_names = [episode_names]
    runs = []
    for episode_name in episode_names:
        runs_ep = self._episodes().get_runs(episode_name)
        if len(runs_ep) == 0:
            raise RuntimeError(
                f"There is no {episode_name} episode in the project memory"
            )
        runs += runs_ep
    if metrics is None:
        metrics = self._episode(runs[0]).get_metrics()

    values = {m: [] for m in metrics}
    for run in runs:
        # One meta object per run is enough for all of its metrics.
        episode = self._episode(run)
        for m in metrics:
            log = episode.get_metric_log(mode="val", metric_name=m)
            if method == "best":
                values[m] += list(sorted(log)[-average:])
            elif method == "last":
                if len(log) == 0:
                    # No log entries: fall back to the value stored in the
                    # episodes table, which only holds a single number.
                    episodes = self._episodes().data
                    if average == 1 and ("results", m) in episodes.columns:
                        values[m] += [episodes.loc[run, ("results", m)]]
                    else:
                        raise RuntimeError(f"Did not find {m} metric for {run} run")
                values[m] += list(log[-average:])
            elif method.startswith("epoch"):
                epoch = int(method[5:]) - 1
                pars = self._episodes().load_parameters(run)
                # The log has one entry per validation step, not per epoch.
                step = int(pars["training"]["validation_interval"])
                values[m] += [log[epoch // step]]
            else:
                raise ValueError(
                    f"The {method} method is not recognized! Please choose from ['last', 'best', 'epoch...']"
                )
    statistics = defaultdict(dict)
    for m, v in values.items():
        statistics[m]["mean"] = np.mean(v)
        statistics[m]["std"] = np.std(v)
    print(f"SUMMARY {episode_names}")
    for m, v in statistics.items():
        print(f'{m}: mean {v["mean"]:.3f}, std {v["std"]:.3f}')
    print("\n")

    return (dict(statistics), values) if return_values else dict(statistics)


@staticmethod
def remove_project(name: str, projects_path: str = None) -> None:
    """Remove all project files and experiment records and results.

    Nothing happens if the project folder does not exist.

    Parameters
    ----------
    name : str
        the name of the project to remove
    projects_path : str, optional
        the path to the projects directory (by default the home DLC2Action directory)

    """
    if projects_path is None:
        projects_path = os.path.join(str(Path.home()), "DLC2Action")
    project_path = os.path.join(projects_path, name)
    if os.path.exists(project_path):
        shutil.rmtree(project_path)
def remove_saved_features(
    self,
    dataset_names: List = None,
    exceptions: List = None,
    remove_active: bool = False,
) -> None:
    """Remove saved pre-computed dataset feature files.

    By default, all features will be deleted.
    No essential information can get lost, storing them only saves time. Be careful with deleting datasets
    while training or inference is happening though.

    Parameters
    ----------
    dataset_names : list, optional
        a list of dataset names to delete (by default all names are added)
    exceptions : list, optional
        a list of dataset names to not be deleted (it is not modified by this method)
    remove_active : bool, default False
        if `False`, datasets used by unfinished episodes will not be deleted

    """
    print("Removing datasets...")
    # Copy the list so that the caller's object is not extended below.
    exceptions = [] if exceptions is None else list(exceptions)
    if not remove_active:
        exceptions += self._episodes().get_active_datasets()
    dataset_path = os.path.join(self.project_path, "saved_datasets")
    if os.path.exists(dataset_path):
        if not dataset_names:
            # Both folders and "<name>.pickle" files map to the same name.
            dataset_names = {f.split(".")[0] for f in os.listdir(dataset_path)}

        folders = [
            x
            for x in dataset_names
            if x not in exceptions and os.path.exists(os.path.join(dataset_path, x))
        ]
        if len(folders) > 2:
            folders = tqdm(folders)
        for dataset in folders:
            shutil.rmtree(os.path.join(dataset_path, dataset))
        pickles = [
            f"{x}.pickle"
            for x in dataset_names
            if x not in exceptions
            and os.path.exists(os.path.join(dataset_path, f"{x}.pickle"))
        ]
        for dataset in pickles:
            os.remove(os.path.join(dataset_path, dataset))
        # NOTE(review): every name reported by the meta file is dropped here,
        # not only the names deleted above - confirm this is intended.
        names = self._saved_datasets().dataset_names()
        self._saved_datasets().remove(names)
    print("\n")
def remove_extra_checkpoints(
    self, episode_names: List = None, exceptions: List = None
) -> None:
    """Remove intermediate model checkpoint files (only leave the files for the last epoch).

    By default, all intermediate checkpoints will be deleted.
    Files in the model folder that are not associated with any record in the meta files are also deleted.

    Parameters
    ----------
    episode_names : list, optional
        a list of episode names to clean (by default all names are added)
    exceptions : list, optional
        a list of episode names to not clean

    """
    import re

    def natural_key(filename: str) -> list:
        # Compare digit runs as numbers so that "epoch100" sorts after
        # "epoch25"; a plain string sort would keep the wrong checkpoint.
        return [
            int(part) if part.isdigit() else part
            for part in re.split(r"(\d+)", filename)
        ]

    model_path = os.path.join(self.project_path, "results", "model")
    if not os.path.isdir(model_path):
        # Nothing has been trained yet, so there is nothing to clean.
        return
    try:
        all_names = self._episodes().data.index
    except Exception:
        # Without a readable meta file every existing folder counts as known.
        all_names = os.listdir(model_path)
    if episode_names is None:
        episode_names = all_names
    if exceptions is None:
        exceptions = []
    to_clean = {x for x in episode_names if x not in exceptions}
    known = set(all_names)
    for folder in os.listdir(model_path):
        folder_path = os.path.join(model_path, folder)
        if folder not in known:
            shutil.rmtree(folder_path)
        elif folder in to_clean:
            checkpoints = sorted(os.listdir(folder_path), key=natural_key)
            for file in checkpoints[:-1]:
                os.remove(os.path.join(folder_path, file))


def remove_search(self, search_name: str) -> None:
    """Remove a hyperparameter search record.

    Parameters
    ----------
    search_name : str
        the name of the search to remove

    """
    self._searches().remove_episode(search_name)
    graph_path = os.path.join(self.project_path, "results", "searches", search_name)
    if os.path.exists(graph_path):
        shutil.rmtree(graph_path)


def remove_suggestion(self, suggestion_name: str) -> None:
    """Remove a suggestion record.

    Parameters
    ----------
    suggestion_name : str
        the name of the suggestion to remove

    """
    self._suggestions().remove_episode(suggestion_name)
    suggestion_path = os.path.join(
        self.project_path, "results", "suggestions", suggestion_name
    )
    if os.path.exists(suggestion_path):
        shutil.rmtree(suggestion_path)
3199 3200 Parameters 3201 ---------- 3202 prediction_name : str 3203 the name of the prediction to remove 3204 3205 """ 3206 self._predictions().remove_episode(prediction_name) 3207 prediction_path = self.prediction_path(prediction_name) 3208 if os.path.exists(prediction_path): 3209 shutil.rmtree(prediction_path) 3210 3211 def check_prediction_exists(self, prediction_name: str) -> str | None: 3212 """Check if a prediction exists. 3213 3214 Parameters 3215 ---------- 3216 prediction_name : str 3217 the name of the prediction to check 3218 3219 Returns 3220 ------- 3221 str | None 3222 the path to the prediction if it exists, `None` otherwise 3223 3224 """ 3225 prediction_path = self.prediction_path(prediction_name) 3226 if os.path.exists(prediction_path): 3227 return prediction_path 3228 return None 3229 3230 def remove_episode(self, episode_name: str) -> None: 3231 """Remove all model, logs and metafile records related to an episode. 3232 3233 Parameters 3234 ---------- 3235 episode_name : str 3236 the name of the episode to remove 3237 3238 """ 3239 runs = self._episodes().get_runs(episode_name) 3240 runs.append(episode_name) 3241 for run in runs: 3242 self._episodes().remove_episode(run) 3243 model_path = os.path.join(self.project_path, "results", "model", run) 3244 if os.path.exists(model_path): 3245 shutil.rmtree(model_path) 3246 log_path = os.path.join(self.project_path, "results", "logs", f"{run}.txt") 3247 if os.path.exists(log_path): 3248 os.remove(log_path) 3249 3250 @abstractmethod 3251 def _reformat_results(res: dict, classes: dict, exclusive=False): 3252 """Add classes to micro metrics in results from evaluation""" 3253 results = deepcopy(res) 3254 for key in results.keys(): 3255 if isinstance(results[key], list): 3256 if exclusive and len(classes) == len(results[key]) + 1: 3257 other_ind = list(classes.keys())[ 3258 list(classes.values()).index("other") 3259 ] 3260 classes = { 3261 (i if i < other_ind else i - 1): c 3262 for i, c in classes.items() 
@staticmethod
def _reformat_results(res: dict, classes: dict, exclusive=False):
    """Add classes to micro metrics in results from evaluation.

    Every metric whose value is a list is replaced by a dictionary that maps
    class names to float values; other values are kept as they are.

    Parameters
    ----------
    res : dict
        a dictionary of metric values (it is not modified)
    classes : dict
        a dictionary with class indices as keys and class names as values
    exclusive : bool, default False
        if `True` and a metric has exactly one value fewer than there are classes, the "other" class is
        dropped and the following indices are shifted down by one

    Returns
    -------
    results : dict
        the reformatted copy of `res`

    """
    results = deepcopy(res)
    for key in results.keys():
        if isinstance(results[key], list):
            if exclusive and len(classes) == len(results[key]) + 1:
                other_ind = list(classes.keys())[
                    list(classes.values()).index("other")
                ]
                classes = {
                    (i if i < other_ind else i - 1): c
                    for i, c in classes.items()
                    if i != other_ind
                }
            assert len(results[key]) == len(
                classes
            ), f"Results for {key} have {len(results[key])} values, but {len(classes)} classes were provided!"
            results[key] = {
                classes[i]: float(v) for i, v in enumerate(results[key])
            }
    return results


def prune_unfinished(self, exceptions: List = None) -> List:
    """Remove all interrupted episodes.

    Remove all episodes that either don't have a log file or have less epochs in the log file than in
    the training parameters or have a model folder but not a record. Note that it can remove episodes that are
    currently running!

    Parameters
    ----------
    exceptions : list
        the episodes to keep even if they are interrupted (this also protects model folders
        that have no record)

    Returns
    -------
    pruned : list
        a list of the episode names that were pruned

    """
    if exceptions is None:
        exceptions = []
    episodes = self._episodes()
    unfinished = [x for x in episodes.unfinished_episodes() if x not in exceptions]
    # Read the recorded names once instead of once per model folder.
    recorded = set(episodes.list_episodes().index)
    model_folders = os.listdir(os.path.join(self.project_path, "results", "model"))
    unfinished += [
        x for x in model_folders if x not in recorded and x not in exceptions
    ]
    print(f"PRUNING {unfinished}")
    for episode_name in unfinished:
        self.remove_episode(episode_name)
    print("\n")
    return unfinished


def prediction_path(self, prediction_name: str) -> str:
    """Get the path where prediction files are saved.

    Parameters
    ----------
    prediction_name : str
        name of the prediction

    Returns
    -------
    prediction_path : str
        the file path

    """
    return os.path.join(
        self.project_path, "results", "predictions", f"{prediction_name}"
    )
def suggestion_path(self, suggestion_name: str) -> str:
    """Get the path where suggestion files are saved.

    Parameters
    ----------
    suggestion_name : str
        name of the suggestion

    Returns
    -------
    suggestion_path : str
        the file path

    """
    return os.path.join(
        self.project_path, "results", "suggestions", f"{suggestion_name}"
    )


@classmethod
def print_data_types(cls):
    """Print available data types with their docstrings."""
    print("DATA TYPES:")
    for key, value in cls.data_types().items():
        print(f"{key}:")
        print(value.__doc__)


@classmethod
def print_annotation_types(cls):
    """Print available annotation types with their docstrings."""
    print("ANNOTATION TYPES:")
    # Iterating the mapping itself yields only the keys, so the (key, value)
    # pairs have to be requested explicitly.
    for key, value in cls.annotation_types().items():
        print(f"{key}:")
        print(value.__doc__)


@staticmethod
def data_types() -> Dict:
    """Get available data types.

    Returns
    -------
    data_types : dict
        available data types (`options.input_stores`), keyed by name

    """
    return options.input_stores


@staticmethod
def annotation_types() -> Dict:
    """Get available annotation types.

    Returns
    -------
    annotation_types : dict
        available annotation types (`options.annotation_stores`), keyed by name

    """
    return options.annotation_stores


def _save_mask(self, file: Dict, mask_name: str):
    """Save a mask file.

    The mask folder is created if it does not exist yet.

    Parameters
    ----------
    file : dict
        the mask file data to save
    mask_name : str
        the name of the mask file

    """
    mask_dir = self._mask_path()
    os.makedirs(mask_dir, exist_ok=True)
    with open(os.path.join(mask_dir, mask_name + ".pickle"), "wb") as f:
        pickle.dump(file, f)
def _load_mask(self, mask_name: str) -> Dict:
    """Load a mask file.

    Parameters
    ----------
    mask_name : str
        the name of the mask file to load

    Returns
    -------
    mask : dict
        the loaded mask data

    """
    # NOTE: pickle files can execute arbitrary code when loaded; only mask
    # files written by the project itself should be placed in this folder.
    with open(os.path.join(self._mask_path(), mask_name + ".pickle"), "rb") as f:
        data = pickle.load(f)
    return data


def _thresholds(self) -> "DecisionThresholds":
    """Get the decision thresholds meta object.

    Returns
    -------
    thresholds : DecisionThresholds
        the decision thresholds meta object

    """
    return DecisionThresholds(self._thresholds_path())


def _episodes(self) -> "SavedRuns":
    """Get the episodes meta object.

    If the meta file cannot be opened, the metadata backup is restored and
    the file is opened again.

    Returns
    -------
    episodes : SavedRuns
        the episodes meta object

    """
    try:
        return SavedRuns(self._episodes_path(), self.project_path)
    except Exception:
        # Only real errors trigger the restore; an interrupt (Ctrl-C) must
        # not overwrite the metadata with the backup.
        self.load_metadata_backup()
        return SavedRuns(self._episodes_path(), self.project_path)


def _suggestions(self) -> "Suggestions":
    """Get the suggestions meta object.

    If the meta file cannot be opened, the metadata backup is restored and
    the file is opened again.

    Returns
    -------
    suggestions : Suggestions
        the suggestions meta object

    """
    try:
        return Suggestions(self._suggestions_path(), self.project_path)
    except Exception:
        self.load_metadata_backup()
        return Suggestions(self._suggestions_path(), self.project_path)


def _predictions(self) -> "SavedRuns":
    """Get the predictions meta object.

    If the meta file cannot be opened, the metadata backup is restored and
    the file is opened again.

    Returns
    -------
    predictions : SavedRuns
        the predictions meta object

    """
    try:
        return SavedRuns(self._predictions_path(), self.project_path)
    except Exception:
        self.load_metadata_backup()
        return SavedRuns(self._predictions_path(), self.project_path)
def _saved_datasets(self) -> "SavedStores":
    """Get the datasets meta object.

    If the meta file cannot be opened, the metadata backup is restored and
    the file is opened again.

    Returns
    -------
    datasets : SavedStores
        the datasets meta object

    """
    try:
        return SavedStores(self._saved_datasets_path())
    except Exception:
        # Only real errors trigger the restore; an interrupt (Ctrl-C) must
        # not overwrite the metadata with the backup.
        self.load_metadata_backup()
        return SavedStores(self._saved_datasets_path())


def _prediction(self, name: str) -> "Run":
    """Get a prediction meta object.

    If the meta file cannot be opened, the metadata backup is restored and
    the file is opened again.

    Parameters
    ----------
    name : str
        episode name

    Returns
    -------
    prediction : Run
        the prediction meta object

    """
    try:
        return Run(name, self.project_path, meta_path=self._predictions_path())
    except Exception:
        self.load_metadata_backup()
        return Run(name, self.project_path, meta_path=self._predictions_path())


def _episode(self, name: str) -> "Run":
    """Get an episode meta object.

    If the meta file cannot be opened, the metadata backup is restored and
    the file is opened again.

    Parameters
    ----------
    name : str
        episode name

    Returns
    -------
    episode : Run
        the episode meta object

    """
    try:
        return Run(name, self.project_path, meta_path=self._episodes_path())
    except Exception:
        self.load_metadata_backup()
        return Run(name, self.project_path, meta_path=self._episodes_path())


def _searches(self) -> "Searches":
    """Get the hyperparameter search meta object.

    If the meta file cannot be opened, the metadata backup is restored and
    the file is opened again.

    Returns
    -------
    searches : Searches
        the searches meta object

    """
    try:
        return Searches(self._searches_path(), self.project_path)
    except Exception:
        self.load_metadata_backup()
        return Searches(self._searches_path(), self.project_path)
def _update_configs(self) -> None:
    """Update the project config files with newly added files and parameters.

    The data path is written to the project configuration first. After that
    every config file of the installed package is compared with the project:
    missing files are copied over, and in existing files the keys that the
    package no longer defines are dropped (except "data_type" and
    "annotation_type") while keys that are new in the package are added.
    Values already set in the project are kept.

    """
    self.update_parameters({"data": {"data_path": self.data_path}})
    package_config = os.path.join(
        os.path.dirname(os.path.dirname(__file__)), "config"
    )
    project_config = os.path.join(self.project_path, "config")

    # Paths of all relevant config files, relative to the package config folder.
    relative_paths = [
        name for name in os.listdir(package_config) if name.endswith("yaml")
    ]
    for subfolder in ("augmentations", "features", "model"):
        for name in os.listdir(os.path.join(package_config, subfolder)):
            relative_paths.append(os.path.join(subfolder, name))
    relative_paths.append(os.path.join("data", f"{self.data_type}.yaml"))
    if self.annotation_type != "none":
        relative_paths.append(
            os.path.join("annotation", f"{self.annotation_type}.yaml")
        )

    protected_keys = ("data_type", "annotation_type")
    for relative in relative_paths:
        source = os.path.join(package_config, relative)
        # Data and annotation files live directly in the project config folder.
        if relative.startswith(("data", "annotation")):
            relative = os.path.basename(relative)
        target = os.path.join(project_config, relative)
        if not os.path.exists(target):
            shutil.copy(source, target)
            continue
        package_pars = self._open_yaml(source)
        project_pars = self._open_yaml(target)
        obsolete = [
            key
            for key in project_pars
            if key not in package_pars and key not in protected_keys
        ]
        for key in obsolete:
            project_pars.pop(key)
        # Only the keys that the project does not have yet are merged in.
        already_set = [key for key in package_pars if key in project_pars]
        for key in already_set:
            package_pars.pop(key)
        project_pars = self._update(project_pars, package_pars)
        with open(target, "w", encoding="utf-8") as f:
            YAML().dump(project_pars, f)
def _update_project(self) -> None:
    """Update project files with the current version.

    Nothing is done when the version file of the project holds a version that
    is at least as new as the installed package (a warning is issued when it
    is newer). Otherwise the annotation config files of the package are
    copied to the project, and parameters that are new in the package are
    added to the existing project files and to the stored episodes.

    """

    def version_key(version: str) -> Tuple[int, ...]:
        # Versions are compared number by number; as plain strings "0.10"
        # would be sorted before "0.9".
        numbers = []
        for piece in version.strip().split("."):
            digits = piece[: len(piece) - len(piece.lstrip("0123456789"))]
            numbers.append(int(digits) if digits else 0)
        while numbers and numbers[-1] == 0:
            numbers.pop()
        return tuple(numbers)

    version_file = self._version_path()
    ok = False
    if os.path.exists(version_file):
        with open(version_file, encoding="utf-8") as f:
            project_version = f.read().strip()
        project_key = version_key(project_version)
        package_key = version_key(__version__)
        ok = project_key >= package_key
        if project_key > package_key:
            warnings.warn(
                f"The project expects a higher dlc2action version ({project_version}), please update!"
            )
    if ok:
        return

    project_annotation_path = os.path.join(
        self.project_path, "config", "annotation"
    )
    package_annotation_path = os.path.join(
        os.path.dirname(os.path.dirname(__file__)), "config", "annotation"
    )
    os.makedirs(project_annotation_path, exist_ok=True)
    episodes = self._episodes()
    project_annotation_configs = os.listdir(project_annotation_path)
    for ann_config in os.listdir(package_annotation_path):
        source = os.path.join(package_annotation_path, ann_config)
        target = os.path.join(project_annotation_path, ann_config)
        if ann_config not in project_annotation_configs:
            # The entries are usually single files, which copytree rejects.
            if os.path.isdir(source):
                shutil.copytree(source, target, dirs_exist_ok=True)
            else:
                shutil.copy(source, target)
            continue
        project_pars = self._open_yaml(target)
        pars = self._open_yaml(source)
        new_keys = set(pars.keys()) - set(project_pars.keys())
        for key in new_keys:
            project_pars[key] = pars[key]
            c = self._get_comment(pars.ca.items.get(key))
            if c and hasattr(project_pars, "yaml_add_eol_comment"):
                project_pars.yaml_add_eol_comment(c, key=key)
            # NOTE(review): ann_config is a file name here (it may include the
            # ".yaml" extension) - confirm that the condition still matches.
            episodes.update(
                condition=f"general/annotation_type::={ann_config}",
                update={f"data/{key}": pars[key]},
            )
        if new_keys:
            # Save the merged parameters, otherwise the additions are lost.
            with open(target, "w", encoding="utf-8") as f:
                YAML().dump(project_pars, f)


def _initialize_project(
    self,
    data_type: str,
    annotation_type: str = None,
    data_path: str = None,
    annotation_path: str = None,
    copy: bool = True,
) -> None:
    """Initialize a new project.

    The project folder structure is created, the input files are copied into
    it when `copy` is `True`, and the config and meta files are generated.

    Parameters
    ----------
    data_type : str
        data type (has to be one of the keys of `data_types()`)
    annotation_type : str, optional
        annotation type (has to be one of the keys of `annotation_types()`)
    data_path : str, optional
        path to the folder containing input files
    annotation_path : str, optional
        path to the folder containing annotation files
    copy : bool, default True
        if `True`, the folders are copied into the project and the copies are
        used as the data and annotation paths

    Raises
    ------
    ValueError
        if the data type or the annotation type is not available

    """
    if data_type not in self.data_types():
        raise ValueError(
            f"The {data_type} data type is not available. "
            f"Please choose from {self.data_types()}"
        )
    if annotation_type not in self.annotation_types():
        raise ValueError(
            f"The {annotation_type} annotation type is not available. "
            f"Please choose from {self.annotation_types()}"
        )
    os.mkdir(self.project_path)
    for folder in ("results", "saved_datasets", "meta", "config"):
        os.mkdir(os.path.join(self.project_path, folder))
    for subfolder in (
        "model",
        "logs",
        "predictions",
        "splits",
        "searches",
        "suggestions",
    ):
        os.mkdir(os.path.join(self.project_path, "results", subfolder))
    if data_path is not None and copy:
        project_data = os.path.join(self.project_path, "data")
        os.mkdir(project_data)
        shutil.copytree(data_path, project_data, dirs_exist_ok=True)
        data_path = project_data
    if annotation_path is not None and copy:
        project_annotation = os.path.join(self.project_path, "annotation")
        os.mkdir(project_annotation)
        shutil.copytree(annotation_path, project_annotation, dirs_exist_ok=True)
        annotation_path = project_annotation
    self._generate_config(
        data_type,
        annotation_type,
        data_path=data_path,
        annotation_path=annotation_path,
    )
    self._generate_meta()


def _read_types(self) -> Tuple[str, str]:
    """Get data type and annotation type from existing project files.

    Returns
    -------
    annotation_type : str
        the annotation type stored in the general config file
    data_type : str
        the data type stored in the general config file

    """
    config_path = os.path.join(self.project_path, "config", "general.yaml")
    with open(config_path, encoding="utf-8") as f:
        pars = YAML().load(f)
    data_type = pars["data_type"]
    annotation_type = pars["annotation_type"]
    return annotation_type, data_type


def _read_paths(self) -> Tuple[str, str]:
    """Get data path and annotation path from existing project files.

    Returns
    -------
    annotation_path : str
        the annotation path stored in the data config file
    data_path : str
        the data path stored in the data config file

    """
    config_path = os.path.join(self.project_path, "config", "data.yaml")
    with open(config_path, encoding="utf-8") as f:
        pars = YAML().load(f)
    data_path = pars["data_path"]
    annotation_path = pars["annotation_path"]
    return annotation_path, data_path


def _generate_config(
    self, data_type: str, annotation_type: str, data_path: str, annotation_path: str
) -> None:
    """Initialize the config files.

    The default config files of the package are copied to the project. The
    data and annotation parameters are merged into a single data config file
    together with the paths, and the chosen types are stored in the general
    config file.

    Parameters
    ----------
    data_type : str
        the data type of the project
    annotation_type : str
        the annotation type of the project (`None` and "none" mean no annotation parameters)
    data_path : str
        the path to the input files
    annotation_path : str
        the path to the annotation files

    Raises
    ------
    ValueError
        if there is no default config file for the annotation type

    """
    default_path = os.path.join(
        os.path.dirname(os.path.dirname(__file__)), "config"
    )
    config_path = os.path.join(self.project_path, "config")
    for name in ("losses", "metrics", "ssl", "training"):
        shutil.copy(os.path.join(default_path, f"{name}.yaml"), config_path)
    for folder in ("model", "features", "augmentations"):
        shutil.copytree(
            os.path.join(default_path, folder), os.path.join(config_path, folder)
        )
    yaml = YAML()
    # A data type without a default config file simply has no parameters;
    # without this default the name would be unbound below.
    data_params = None
    data_param_path = os.path.join(default_path, "data", f"{data_type}.yaml")
    if os.path.exists(data_param_path):
        with open(data_param_path, encoding="utf-8") as f:
            data_params = yaml.load(f)
    if data_params is None:
        data_params = {}
    if annotation_type is None:
        ann_params = {}
    else:
        ann_param_path = os.path.join(
            default_path, "annotation", f"{annotation_type}.yaml"
        )
        if os.path.exists(ann_param_path):
            ann_params = self._open_yaml(ann_param_path)
        elif annotation_type == "none":
            ann_params = {}
        else:
            raise ValueError(
                f"The {annotation_type} annotation type is not available. "
                f"Please choose from {BehaviorDataset.annotation_types()}"
            )
    if ann_params is None:
        ann_params = {}
    data_params = self._update(data_params, ann_params)
    data_params["data_path"] = data_path
    data_params["annotation_path"] = annotation_path
    with open(os.path.join(config_path, "data.yaml"), "w", encoding="utf-8") as f:
        yaml.dump(data_params, f)
    with open(os.path.join(default_path, "general.yaml"), encoding="utf-8") as f:
        general_params = yaml.load(f)
    general_params["data_type"] = data_type
    general_params["annotation_type"] = annotation_type
    with open(
        os.path.join(config_path, "general.yaml"), "w", encoding="utf-8"
    ) as f:
        yaml.dump(general_params, f)
def _compare(self, d: Dict, u: Dict, allow_diff: float = 1e-7):
    """Compare nested dictionaries with 'almost equal' condition.

    Parameters
    ----------
    d : dict
        the first dictionary
    u : dict
        the second dictionary
    allow_diff : float, default 1e-7
        the largest absolute difference at which two numbers still count as equal
        (applied whenever at least one of the two values is a float)

    Returns
    -------
    ok : bool
        `True` if both dictionaries have the same keys and all values match

    """
    if u.keys() != d.keys():
        return False
    for k, v in u.items():
        other = d[k]
        if isinstance(v, Mapping):
            # A mapping can only match another mapping. Return on the first
            # mismatch: the result of a later nested comparison must never
            # overwrite a difference that was already found.
            if not isinstance(other, Mapping):
                return False
            if not self._compare(other, v, allow_diff=allow_diff):
                return False
        elif isinstance(v, float) or isinstance(other, float):
            # tolerance comparison is only defined between two numbers
            if not isinstance(other, (float, int)) or not isinstance(v, (float, int)):
                return False
            if abs(v - other) > allow_diff:
                return False
        elif v != other:
            return False
    return True
def _update_with_search(
    self,
    d: Dict,
    search_name: str,
    load_parameters: list = None,
    round_to_binary: list = None,
):
    """Merge the best parameters found by a hyperparameter search into a dictionary.

    Parameters
    ----------
    d : dict
        the dictionary to update
    search_name : str
        the name of the hyperparameter search
    load_parameters : list, optional
        passed through to `get_best_params` of the searches record
    round_to_binary : list, optional
        passed through to `get_best_params` of the searches record

    Returns
    -------
    updated : dict
        the result of `self._update(d, best_params)`

    """
    searches = self._searches()
    # the second element of the returned pair is not needed here
    best_params, _ = searches.get_best_params(
        search_name, load_parameters, round_to_binary
    )
    updated = self._update(d, best_params)
    return updated
def set_main_parameters(self, model_name: str = None, metric_names: List = None):
    """Select the model and the metrics.

    The chosen values are written to the `general` group of the project
    parameters through `update_parameters`. Each name is checked against
    `options.models` / `options.metrics` with an `assert`.

    Parameters
    ----------
    model_name : str, optional
        model name; run `project.help("model")` to find out more
    metric_names : list, optional
        a list of metric function names; run `project.help("metrics")` to find out more

    """
    general = {}
    if model_name is not None:
        assert model_name in options.models
        general["model_name"] = model_name
    if metric_names is not None:
        for metric in metric_names:
            assert metric in options.metrics
        general["metric_functions"] = metric_names
    self.update_parameters({"general": general})
3980 3981 Parameters 3982 ---------- 3983 keyword : str, optional 3984 the keyword for options (run without arguments to see which keywords are available) 3985 3986 """ 3987 if keyword is None: 3988 print("AVAILABLE HELP FUNCTIONS:") 3989 print("- Try running `project.help(keyword)` with the following keywords:") 3990 print(" - model: to get more information on available models,") 3991 print( 3992 " - features: to get more information on available feature extraction modes," 3993 ) 3994 print( 3995 " - partition_method: to get more information on available train/test/val partitioning methods," 3996 ) 3997 print(" - metrics: to see a list of available metric functions.") 3998 print(" - data: to see help for expected data structure") 3999 print( 4000 "- To start working with this project, first run `project.list_blanks()` to check which parameters need to be filled in." 4001 ) 4002 print( 4003 "- After a model and metrics are set, run `project.list_basic_parameters()` to see a list of the most important parameters that you might want to modify" 4004 ) 4005 print( 4006 f"- If you want to dig deeper, get the full dictionary with project._read_parameters() (it is a `ruamel.yaml.comments.CommentedMap` instance)." 
def _get_blanks(self):
    """Collect the parameters that are still set to the `"???"` placeholder.

    The parameters are read with `_read_parameters(catch_blanks=False)`, so
    this method itself never raises because of blanks.

    Returns
    -------
    caught : list
        a list of `(group, key, comment)` tuples, one per blank parameter, where
        `comment` is the comment attached to the key in the config file

    """
    config = self._read_parameters(catch_blanks=False)
    return [
        (group, name, self._get_comment(entries.ca.items.get(name)))
        for group, entries in config.items()
        for name, setting in entries.items()
        if setting == "???"
    ]
with relevant values.") 4111 else: 4112 print("There is no blanks left!") 4113 4114 def list_basic_parameters( 4115 self, 4116 ): 4117 """Get a list of most relevant parameters and code to modify them.""" 4118 parameters = self._read_parameters() 4119 print("BASIC PARAMETERS:") 4120 model_name = parameters["general"]["model_name"] 4121 metric_names = parameters["general"]["metric_functions"] 4122 loss_name = parameters["general"]["loss_function"] 4123 feature_extraction = parameters["general"]["feature_extraction"] 4124 print("Here is a list of current parameters.") 4125 print( 4126 "You can copy this code, change the parameters you want to set and run it to update the project config." 4127 ) 4128 print("--------------------------------------------------------") 4129 print("project.update_parameters(") 4130 print(" {") 4131 for group in ["general", "data", "training"]: 4132 print(f' "{group}": {{') 4133 for key in options.basic_parameters[group]: 4134 if key in parameters[group]: 4135 print( 4136 f' "{key}": {self._process_value(parameters[group][key])}, {self._get_comment(parameters[group].ca.items.get(key))}' 4137 ) 4138 print(" },") 4139 print(' "losses": {') 4140 print(f' "{loss_name}": {{') 4141 for key in options.basic_parameters["losses"][loss_name]: 4142 if key in parameters["losses"][loss_name]: 4143 print( 4144 f' "{key}": {self._process_value(parameters["losses"][loss_name][key])}, {self._get_comment(parameters["losses"][loss_name].ca.items.get(key))}' 4145 ) 4146 print(" },") 4147 print(" },") 4148 print(' "metrics": {') 4149 for metric in metric_names: 4150 print(f' "{metric}": {{') 4151 for key in parameters["metrics"][metric]: 4152 print( 4153 f' "{key}": {self._process_value(parameters["metrics"][metric][key])}, {self._get_comment(parameters["metrics"][metric].ca.items.get(key))}' 4154 ) 4155 print(" },") 4156 print(" },") 4157 print(' "model": {') 4158 for key in options.basic_parameters["model"][model_name]: 4159 if key in parameters["model"]: 
4160 print( 4161 f' "{key}": {self._process_value(parameters["model"][key])}, {self._get_comment(parameters["model"].ca.items.get(key))}' 4162 ) 4163 4164 print(" },") 4165 print(' "features": {') 4166 for key in options.basic_parameters["features"][feature_extraction]: 4167 if key in parameters["features"]: 4168 print( 4169 f' "{key}": {self._process_value(parameters["features"][key])}, {self._get_comment(parameters["features"].ca.items.get(key))}' 4170 ) 4171 4172 print(" },") 4173 print(' "augmentations": {') 4174 for key in options.basic_parameters["augmentations"][feature_extraction]: 4175 if key in parameters["augmentations"]: 4176 print( 4177 f' "{key}": {self._process_value(parameters["augmentations"][key])}, {self._get_comment(parameters["augmentations"].ca.items.get(key))}' 4178 ) 4179 print(" },") 4180 print(" },") 4181 print(")") 4182 print("--------------------------------------------------------") 4183 print("\n") 4184 4185 def _create_record( 4186 self, 4187 episode_name: str, 4188 behaviors_dict: Dict, 4189 load_episode: str = None, 4190 parameters_update: Dict = None, 4191 task: TaskDispatcher = None, 4192 load_epoch: int = None, 4193 load_search: str = None, 4194 load_parameters: list = None, 4195 round_to_binary: list = None, 4196 load_strict: bool = True, 4197 n_seeds: int = 1, 4198 ) -> TaskDispatcher: 4199 """Create a meta data episode record.""" 4200 if episode_name in self._episodes().data.index: 4201 return 4202 if type(n_seeds) is not int or n_seeds < 1: 4203 raise ValueError( 4204 f"The n_seeds parameter has to be an integer larger than 0; got {n_seeds}" 4205 ) 4206 if parameters_update is None: 4207 parameters_update = {} 4208 parameters = self._read_parameters() 4209 parameters = self._update(parameters, parameters_update) 4210 if load_search is not None: 4211 parameters = self._update_with_search( 4212 parameters, load_search, load_parameters, round_to_binary 4213 ) 4214 parameters = self._fill( 4215 parameters, 4216 episode_name, 4217 
def _save_episode(
    self,
    episode_name: str,
    parameters: Dict,
    behaviors_dict: Dict,
    suppress_validation: bool = False,
    training_time: str = None,
    norm_stats: Dict = None,
) -> None:
    """Save an episode in the meta files.

    Before saving, `parameters["training"]` is modified in place: the
    partition method is replaced with the one derived from the split file
    name (when one can be derived) and the normalization statistics are
    stored under `"stats"`.

    Parameters
    ----------
    episode_name : str
        the name of the episode
    parameters : dict
        the parameters dictionary of the episode (must contain a `"training"` group)
    behaviors_dict : dict
        the behaviors dictionary, passed through to the episodes record
    suppress_validation : bool, default False
        passed through to the episodes record
    training_time : str, optional
        passed through to the episodes record
    norm_stats : dict, optional
        normalization statistics; converted to a plain `dict` before saving

    """
    training = parameters["training"]
    # Explicit lookups instead of a bare `except: pass`: a missing split path or
    # an empty split description simply leaves the partition method untouched,
    # while unrelated errors (and interrupts) are no longer silently swallowed.
    split_info = self._split_info_from_filename(training.get("split_path"))
    if "partition_method" in split_info:
        training["partition_method"] = split_info["partition_method"]
    training["stats"] = None if norm_stats is None else dict(norm_stats)
    self._episodes().save_episode(
        episode_name,
        parameters,
        behaviors_dict,
        suppress_validation=suppress_validation,
        training_time=training_time,
    )
def _save_stores(self, parameters: Dict) -> None:
    """Register a pickled dataset in the meta files and back up the metadata.

    Parameters
    ----------
    parameters : dict
        the parameters dictionary; the record is named after the last path
        component of `parameters["data"]["feature_save_path"]`

    """
    feature_folder = parameters["data"]["feature_save_path"]
    store_name = os.path.basename(feature_folder)
    datasets = self._saved_datasets()
    datasets.save_store(store_name, self._get_data_pars(parameters))
    self.create_metadata_backup()
def _check_episode_validity(
    self, episode_name: str, allow_doublecolon: bool = False, force: bool = False
) -> None:
    """Check whether the episode name is valid.

    Parameters
    ----------
    episode_name : str
        the episode name to check
    allow_doublecolon : bool, default False
        if `True`, names containing '#' are accepted (despite the parameter name,
        this flag gates the '#' check; names containing '::' are always rejected)
    force : bool, default False
        if `True`, an existing episode with this name is removed instead of
        causing an error

    Raises
    ------
    ValueError
        if the name starts with an underscore, contains '.', '#' (unless allowed)
        or '::', or is already taken while `force` is `False`

    """
    # (violated, message) pairs, checked in order; the first violation wins
    checks = (
        (
            episode_name.startswith("_"),
            "Names starting with an underscore are reserved by dlc2action and cannot be used!",
        ),
        ("." in episode_name, "Names containing '.' cannot be used!"),
        (
            not allow_doublecolon and "#" in episode_name,
            "Names containing '#' are reserved by dlc2action and cannot be used!",
        ),
        (
            "::" in episode_name,
            "Names containing '::' are reserved by dlc2action and cannot be used!",
        ),
    )
    for violated, message in checks:
        if violated:
            raise ValueError(message)
    if force:
        self.remove_episode(episode_name)
        return
    if not self._episodes().check_name_validity(episode_name):
        raise ValueError(
            f"The {episode_name} name is already taken! Set force=True to overwrite."
        )
def _check_suggestions_validity(
    self, suggestions_name: str, force: bool = False
) -> None:
    """Check whether the suggestions name is valid.

    Parameters
    ----------
    suggestions_name : str
        the suggestions name to check
    force : bool, default False
        if `True`, an existing suggestion with this name is removed instead of
        causing an error

    Raises
    ------
    ValueError
        if the name starts with an underscore, contains '.', or is already
        taken while `force` is `False`

    """
    if suggestions_name.startswith("_"):
        raise ValueError(
            "Names starting with an underscore are reserved by dlc2action and cannot be used!"
        )
    if "." in suggestions_name:
        raise ValueError("Names containing '.' cannot be used!")
    if force:
        # overwrite mode: drop the old record, nothing left to validate
        self.remove_suggestion(suggestions_name)
        return
    name_is_free = self._suggestions().check_name_validity(suggestions_name)
    if not name_is_free:
        raise ValueError(f"The {suggestions_name} name is already taken!")
def _al_points_path(self, suggestions_name: str) -> str:
    """Get the path to an active learning intervals file.

    Parameters
    ----------
    suggestions_name : str
        the name of the suggestions

    Returns
    -------
    path : str
        `{project_path}/results/suggestions/{suggestions_name}/{suggestions_name}_al_points.pickle`

    """
    file_name = f"{suggestions_name}_al_points.pickle"
    folder = os.path.join(
        self.project_path, "results", "suggestions", suggestions_name
    )
    return os.path.join(folder, file_name)
def _search_path(self, name: str) -> str:
    """Get the default path to the graph folder for a specific hyperparameter search.

    Parameters
    ----------
    name : str
        the name of the search

    Returns
    -------
    path : str
        `{project_path}/results/searches/{name}`

    """
    parts = (self.project_path, "results", "searches", name)
    return os.path.join(*parts)
def _split_info_from_filename(self, split_name: str) -> Dict:
    """Get split parameters from default path to a split file.

    The file name is expected to look like
    `{method}_val{X}%_test{Y}%_len{N}_overlap{M}[_all].txt`. Any name that cannot
    be parsed this way is treated as a user-provided split file.

    Parameters
    ----------
    split_name : str
        the name/path of the split file

    Returns
    -------
    split_info : dict
        the split parameters dictionary; empty if `split_name` is `None` and
        `{"partition_method": "file"}` if the name does not follow the default
        pattern

    """
    if split_name is None:
        return {}
    try:
        # drop the four-character extension (".txt" for default split files)
        name = os.path.basename(split_name)[:-4]
        split = name.split("_")
        # six tokens means the optional "_all" suffix is present
        only_load_annotated = len(split) != 6
        len_segment = int(split[3][3:])  # "len{N}"
        overlap = float(split[4][7:])  # "overlap{M}"
        if overlap > 1:
            # values above 1 are reported as integers rather than fractions
            overlap = int(overlap)
        method, val, test = split[:3]
        val = float(val[3:-1]) / 100  # "val{X}%"
        test = float(test[4:-1]) / 100  # "test{Y}%"
    except (TypeError, ValueError, IndexError, OverflowError):
        # Only parsing failures mean "not a default name"; anything else
        # (e.g. KeyboardInterrupt) must not be swallowed here.
        return {"partition_method": "file"}
    return {
        "partition_method": method,
        "val_frac": val,
        "test_frac": test,
        "only_load_annotated": only_load_annotated,
        "len_segment": len_segment,
        "overlap": overlap,
    }
continuing: bool = False, 4668 enforce_split_parameters: bool = False, 4669 ) -> Dict: 4670 """Update the parameters from the config files with project specific information. 4671 4672 Fill in the constant file path parameters and generate a unique log file and a model folder. 4673 Fill in the split file if the same split has been run before in the project and change partition method to 4674 from_file. 4675 Fill in saved data path if a dataset with the same data parameters already exists in the project. 4676 If load_experiment is not None, fill in the checkpoint path as well. 4677 The only_load_model training parameter is defined by the corresponding argument. 4678 If continuing is True, new files are not created and all information is loaded from load_experiment. 4679 If prediction is True, log and model files are not created. 4680 The enforce_split_parameters parameter is used to resolve conflicts 4681 between split file path and split parameters when they arise. 4682 4683 Parameters 4684 ---------- 4685 parameters : dict 4686 the parameters dictionary to update 4687 episode_name : str 4688 the name of the episode 4689 load_experiment : str, optional 4690 the name of the experiment to load from 4691 load_epoch : int, optional 4692 the epoch to load (by default the last one) 4693 load_strict : bool, default True 4694 if `True`, strict loading is enforced 4695 only_load_model : bool, default False 4696 if `True`, only the model is loaded 4697 continuing : bool, default False 4698 if `True`, continues from existing files 4699 enforce_split_parameters : bool, default False 4700 if `True`, split parameters are enforced 4701 4702 Returns 4703 ------- 4704 parameters : dict 4705 the updated parameters dictionary 4706 4707 """ 4708 pars = deepcopy(parameters) 4709 if episode_name == "_": 4710 self.remove_episode("_") 4711 log = os.path.join(self.project_path, "results", "logs", f"{episode_name}.txt") 4712 model_save_path = os.path.join( 4713 self.project_path, "results", 
"model", episode_name 4714 ) 4715 if not continuing and (os.path.exists(log) or os.path.exists(model_save_path)): 4716 raise ValueError( 4717 f"The {episode_name} episode name is already in use! Set force=True to overwrite." 4718 ) 4719 keys = ["val_frac", "test_frac", "partition_method"] 4720 if "len_segment" not in pars["general"] and "len_segment" in pars["data"]: 4721 pars["general"]["len_segment"] = pars["data"]["len_segment"] 4722 if "overlap" not in pars["general"] and "overlap" in pars["data"]: 4723 pars["general"]["overlap"] = pars["data"]["overlap"] 4724 if "len_segment" in pars["data"]: 4725 pars["data"].pop("len_segment") 4726 if "overlap" in pars["data"]: 4727 pars["data"].pop("overlap") 4728 split_info = {k: pars["training"][k] for k in keys} 4729 split_info["only_load_annotated"] = pars["general"]["only_load_annotated"] 4730 split_info["len_segment"] = pars["general"]["len_segment"] 4731 split_info["overlap"] = pars["general"]["overlap"] 4732 pars["training"]["log_file"] = log 4733 if not os.path.exists(model_save_path): 4734 os.mkdir(model_save_path) 4735 pars["training"]["model_save_path"] = model_save_path 4736 if load_experiment is not None: 4737 if load_experiment not in self._episodes().data.index: 4738 raise ValueError(f"The {load_experiment} episode does not exist!") 4739 old_episode = self._episode(load_experiment) 4740 old_file = old_episode.split_file() 4741 old_info = self._split_info_from_filename(old_file) 4742 if len(old_info) == 0: 4743 old_info = old_episode.split_info() 4744 if enforce_split_parameters: 4745 if split_info["partition_method"] != "file": 4746 pars["training"]["split_path"] = self._default_split_file( 4747 split_info 4748 ) 4749 else: 4750 equal = True 4751 if old_info["partition_method"] != split_info["partition_method"]: 4752 equal = False 4753 if old_info["partition_method"] != "file": 4754 if ( 4755 old_info["val_frac"] != split_info["val_frac"] 4756 or old_info["test_frac"] != split_info["test_frac"] 4757 ): 4758 
equal = False 4759 if not continuing and not equal: 4760 warnings.warn( 4761 f"The partitioning parameters in the loaded experiment ({old_info}) " 4762 f"are not equal to the current partitioning parameters ({split_info}). " 4763 f"The current parameters are replaced." 4764 ) 4765 pars["training"]["split_path"] = old_file 4766 for k, v in old_info.items(): 4767 pars["training"][k] = v 4768 pars["training"]["checkpoint_path"] = old_episode.model_file(load_epoch) 4769 pars["training"]["load_strict"] = load_strict 4770 else: 4771 pars["training"]["checkpoint_path"] = None 4772 if pars["training"]["partition_method"] == "file": 4773 if ( 4774 "split_path" not in pars["training"] 4775 or pars["training"]["split_path"] is None 4776 ): 4777 raise ValueError( 4778 "The partition_method parameter is set to file but the " 4779 "split_path parameter is not set!" 4780 ) 4781 elif not os.path.exists(pars["training"]["split_path"]): 4782 raise ValueError( 4783 f'The {pars["training"]["split_path"]} split file does not exist' 4784 ) 4785 else: 4786 pars["training"]["split_path"] = self._default_split_file(split_info) 4787 pars["training"]["only_load_model"] = only_load_model 4788 pars["data"]["saved_data_path"] = None 4789 pars["data"]["feature_save_path"] = None 4790 pars_data_copy = self._get_data_pars(pars) 4791 saved_data_name = self._saved_datasets().find_name(pars_data_copy) 4792 if saved_data_name is not None: 4793 pars["data"]["saved_data_path"] = self._dataset_store_path(saved_data_name) 4794 pars["data"]["feature_save_path"] = self._dataset_store_path( 4795 saved_data_name 4796 ).split(".")[0] 4797 else: 4798 dataset_path = self._dataset_store_path(episode_name) 4799 if os.path.exists(dataset_path): 4800 name, ext = dataset_path.split(".") 4801 i = 0 4802 while os.path.exists(f"{name}_{i}.{ext}"): 4803 i += 1 4804 dataset_path = f"{name}_{i}.{ext}" 4805 pars["data"]["saved_data_path"] = dataset_path 4806 pars["data"]["feature_save_path"] = dataset_path.split(".")[0] 
4807 split_split = pars["training"]["partition_method"].split(":") 4808 random = True 4809 for partition_method in options.partition_methods["fixed"]: 4810 method_split = partition_method.split(":") 4811 if len(split_split) != len(method_split): 4812 continue 4813 equal = True 4814 for x, y in zip(split_split, method_split): 4815 if y.startswith("{"): 4816 continue 4817 if x != y: 4818 equal = False 4819 break 4820 if equal: 4821 random = False 4822 break 4823 if random and os.path.exists(pars["training"]["split_path"]): 4824 pars["training"]["partition_method"] = "file" 4825 pars["general"]["save_dataset"] = True 4826 # Check len_segment for c2f models 4827 if pars["general"]["model_name"].startswith("c2f"): 4828 if int(pars["general"]["len_segment"]) < 512: 4829 raise ValueError( 4830 "The segment length should be higher than 512 when using one of the C2F models" 4831 ) 4832 return pars 4833 4834 def _get_data_pars(self, pars: Dict) -> Dict: 4835 """Get a complete description of the data from a general parameters dictionary. 
4836 4837 Parameters 4838 ---------- 4839 pars : dict 4840 the general parameters dictionary 4841 4842 Returns 4843 ------- 4844 pars_data : dict 4845 the complete data parameters dictionary 4846 4847 """ 4848 pars_data_copy = deepcopy(pars["data"]) 4849 for par in [ 4850 "only_load_annotated", 4851 "exclusive", 4852 "feature_extraction", 4853 "ignored_clips", 4854 "len_segment", 4855 "overlap", 4856 ]: 4857 pars_data_copy[par] = pars["general"].get(par, None) 4858 pars_data_copy.update(pars["features"]) 4859 return pars_data_copy 4860 4861 def _make_al_points_from_suggestions( 4862 self, 4863 suggestions_name: str, 4864 task: TaskDispatcher, 4865 predicted_classes: Dict, 4866 background_threshold: Optional[float], 4867 visibility_min_score: float, 4868 visibility_min_frac: float, 4869 num_behaviors: int, 4870 ): 4871 valleys = [] 4872 if background_threshold is not None: 4873 for i in range(num_behaviors): 4874 print(f"generating background for behavior {i}...") 4875 valleys.append( 4876 task.dataset("train").find_valleys( 4877 predicted_classes, 4878 threshold=background_threshold, 4879 visibility_min_score=visibility_min_score, 4880 visibility_min_frac=visibility_min_frac, 4881 main_class=i, 4882 low=True, 4883 cut_annotated=True, 4884 min_frames=1, 4885 ) 4886 ) 4887 valleys = task.dataset("train").valleys_intersection(valleys) 4888 folder = os.path.join( 4889 self.project_path, "results", "suggestions", suggestions_name 4890 ) 4891 os.makedirs(os.path.dirname(folder), exist_ok=True) 4892 res = {} 4893 for file in os.listdir(folder): 4894 video_id = file.split("_suggestion.p")[0] 4895 res[video_id] = [] 4896 with open(os.path.join(folder, file), "rb") as f: 4897 data = pickle.load(f) 4898 for clip_id, ind_list in zip(data[2], data[3]): 4899 max_len = max( 4900 [ 4901 max([x[1] for x in cat_list]) if len(cat_list) > 0 else 0 4902 for cat_list in ind_list 4903 ] 4904 ) 4905 if max_len == 0: 4906 continue 4907 arr = torch.zeros(max_len) 4908 for cat_list in 
    def _make_al_points(
        self,
        task: TaskDispatcher,
        predicted_error: torch.Tensor,
        predicted_classes: torch.Tensor,
        exclude_classes: List,
        exclude_threshold: List,
        exclude_threshold_diff: List,
        exclude_hysteresis: List,
        include_classes: List,
        include_threshold: List,
        include_threshold_diff: List,
        include_hysteresis: List,
        error_episode: str = None,
        error_class: str = None,
        suggestion_episodes: List = None,
        error_threshold: float = 0.5,
        error_threshold_diff: float = 0.1,
        error_hysteresis: bool = False,
        min_frames_al: int = 30,
        visibility_min_score: float = 5,
        visibility_min_frac: float = 0.7,
    ) -> Dict:
        """Generate an active learning file.

        If any include / exclude classes are given, intervals are searched in
        `predicted_classes` for every such class and combined; otherwise the
        intervals are searched in `predicted_error` for `error_class` of
        `error_episode`. Finally, intervals of the same clip that are closer
        than 30 frames to each other are merged.

        Parameters
        ----------
        task : TaskDispatcher
            the task dispatcher object (its "train" dataset is searched)
        predicted_error : torch.Tensor
            the error predictions passed to `find_valleys`
        predicted_classes : torch.Tensor
            the class predictions passed to `find_valleys`
        exclude_classes, exclude_threshold, exclude_threshold_diff, exclude_hysteresis : list
            per-class settings, zipped together, searched with `low=True`
        include_classes, include_threshold, include_threshold_diff, include_hysteresis : list
            per-class settings, zipped together, searched with `low=False`
        error_episode : str, optional
            the episode used to resolve `error_class` (error-only branch)
        error_class : str, optional
            the class name looked up in `error_episode` (error-only branch)
        suggestion_episodes : list, optional
            the first entry is used to resolve class indices for the
            include / exclude classes
        error_threshold : float, default 0.5
            passed through to `find_valleys`
        error_threshold_diff : float, default 0.1
            passed as `threshold_diff` in the error-only branch
        error_hysteresis : bool, default False
            passed as `hysteresis` in the error-only branch
        min_frames_al : int, default 30
            passed as `min_frames` and `min_frames_error` to `find_valleys`
        visibility_min_score : float, default 5
            passed through to `find_valleys`
        visibility_min_frac : float, default 0.7
            passed through to `find_valleys`

        Returns
        -------
        al_points : dict
            a dictionary where keys are video ids and values are sorted lists
            of merged intervals (the last element of an interval is the clip id)

        """
        if len(exclude_classes) > 0 or len(include_classes) > 0:
            valleys = []
            included = None
            excluded = None
            # NOTE(review): the low-probability intervals of the *excluded*
            # classes are stored in `included` and the high-probability
            # intervals of the *included* classes in `excluded`; the result
            # is symmetric in the two, but the names look swapped -- confirm.
            for class_name, thr, thr_diff, hysteresis in zip(
                exclude_classes,
                exclude_threshold,
                exclude_threshold_diff,
                exclude_hysteresis,
            ):
                episode = self._episodes().get_runs(suggestion_episodes[0])[0]
                class_index = self._episode(episode).get_class_ind(class_name)
                valleys.append(
                    task.dataset("train").find_valleys(
                        predicted_classes,
                        predicted_error=predicted_error,
                        min_frames=min_frames_al,
                        threshold=thr,
                        visibility_min_score=visibility_min_score,
                        visibility_min_frac=visibility_min_frac,
                        error_threshold=error_threshold,
                        main_class=class_index,
                        low=True,
                        threshold_diff=thr_diff,
                        min_frames_error=min_frames_al,
                        hysteresis=hysteresis,
                    )
                )
            if len(valleys) > 0:
                included = task.dataset("train").valleys_union(valleys)
            valleys = []
            for class_name, thr, thr_diff, hysteresis in zip(
                include_classes,
                include_threshold,
                include_threshold_diff,
                include_hysteresis,
            ):
                episode = self._episodes().get_runs(suggestion_episodes[0])[0]
                class_index = self._episode(episode).get_class_ind(class_name)
                valleys.append(
                    task.dataset("train").find_valleys(
                        predicted_classes,
                        predicted_error=predicted_error,
                        min_frames=min_frames_al,
                        threshold=thr,
                        visibility_min_score=visibility_min_score,
                        visibility_min_frac=visibility_min_frac,
                        error_threshold=error_threshold,
                        main_class=class_index,
                        low=False,
                        threshold_diff=thr_diff,
                        min_frames_error=min_frames_al,
                        hysteresis=hysteresis,
                    )
                )
            if len(valleys) > 0:
                excluded = task.dataset("train").valleys_union(valleys)
            # NOTE(review): when only one of the two class lists is non-empty,
            # one of `included` / `excluded` is still None here -- verify that
            # `valleys_intersection` accepts None entries.
            al_points = task.dataset("train").valleys_intersection([included, excluded])
        else:
            class_index = self._episode(error_episode).get_class_ind(error_class)
            print("generating active learning intervals...")
            al_points = task.dataset("train").find_valleys(
                predicted_error,
                min_frames=min_frames_al,
                threshold=error_threshold,
                visibility_min_score=visibility_min_score,
                visibility_min_frac=visibility_min_frac,
                main_class=class_index,
                low=True,
                threshold_diff=error_threshold_diff,
                min_frames_error=min_frames_al,
                hysteresis=error_hysteresis,
            )
        # Merge intervals of the same clip that are separated by a short gap.
        for v_id in al_points:
            clip_dict = defaultdict(lambda: [])
            res = []
            # group the intervals by clip id (the last element of an interval)
            for x in al_points[v_id]:
                clip_dict[x[-1]].append(x)
            for clip_id in clip_dict:
                clip_dict[clip_id] = sorted(clip_dict[clip_id])
                # `i` is the interval currently being extended, `j` the candidate
                i = 0
                j = 1
                while j < len(clip_dict[clip_id]):
                    end = clip_dict[clip_id][i][1]
                    start = clip_dict[clip_id][j][0]
                    # NOTE(review): the merge gap is hard-coded to 30 frames
                    # (independent of `min_frames_al`); the in-place update
                    # assumes the intervals are mutable lists -- confirm.
                    if start - end < 30:
                        clip_dict[clip_id][i][1] = clip_dict[clip_id][j][1]
                    else:
                        res.append(clip_dict[clip_id][i])
                        i = j
                    j += 1
                res.append(clip_dict[clip_id][i])
            al_points[v_id] = sorted(res)
        return al_points
clip_dict[clip_id] = sorted(clip_dict[clip_id]) 5033 i = 0 5034 j = 1 5035 while j < len(clip_dict[clip_id]): 5036 end = clip_dict[clip_id][i][1] 5037 start = clip_dict[clip_id][j][0] 5038 if start - end < 30: 5039 clip_dict[clip_id][i][1] = clip_dict[clip_id][j][1] 5040 else: 5041 res.append(clip_dict[clip_id][i]) 5042 i = j 5043 j += 1 5044 res.append(clip_dict[clip_id][i]) 5045 al_points[v_id] = sorted(res) 5046 return al_points 5047 5048 def _make_suggestions( 5049 self, 5050 task: TaskDispatcher, 5051 predicted_error: torch.Tensor, 5052 predicted_classes: torch.Tensor, 5053 suggestion_threshold: List, 5054 suggestion_threshold_diff: List, 5055 suggestion_hysteresis: List, 5056 suggestion_episodes: List = None, 5057 suggestion_classes: List = None, 5058 error_threshold: float = 0.5, 5059 min_frames_suggestion: int = 3, 5060 min_frames_al: int = 30, 5061 visibility_min_score: float = 0, 5062 visibility_min_frac: float = 0.7, 5063 cut_annotated: bool = False, 5064 ) -> Dict: 5065 """Make a suggestions dictionary.""" 5066 suggestions = defaultdict(lambda: {}) 5067 for class_name, thr, thr_diff, hysteresis in zip( 5068 suggestion_classes, 5069 suggestion_threshold, 5070 suggestion_threshold_diff, 5071 suggestion_hysteresis, 5072 ): 5073 episode = self._episodes().get_runs(suggestion_episodes[0])[0] 5074 class_index = self._episode(episode).get_class_ind(class_name) 5075 print(f"generating suggestions for {class_name}...") 5076 found = task.dataset("train").find_valleys( 5077 predicted_classes, 5078 smooth_interval=2, 5079 predicted_error=predicted_error, 5080 min_frames=min_frames_suggestion, 5081 threshold=thr, 5082 visibility_min_score=visibility_min_score, 5083 visibility_min_frac=visibility_min_frac, 5084 error_threshold=error_threshold, 5085 main_class=class_index, 5086 low=False, 5087 threshold_diff=thr_diff, 5088 min_frames_error=min_frames_al, 5089 hysteresis=hysteresis, 5090 cut_annotated=cut_annotated, 5091 ) 5092 for v_id in found: 5093 
suggestions[v_id][class_name] = found[v_id] 5094 suggestions = dict(suggestions) 5095 return suggestions 5096 5097 def count_classes( 5098 self, 5099 load_episode: str = None, 5100 parameters_update: Dict = None, 5101 remove_saved_features: bool = False, 5102 bouts: bool = True, 5103 ) -> Dict: 5104 """Get a dictionary of class counts in different modes. 5105 5106 Parameters 5107 ---------- 5108 load_episode : str, optional 5109 the episode settings to load 5110 parameters_update : dict, optional 5111 a dictionary of parameter updates (only for "data" and "general" categories) 5112 remove_saved_features : bool, default False 5113 if `True`, the dataset that is used for computation is then deleted 5114 bouts : bool, default False 5115 if `True`, instead of frame counts segment counts are returned 5116 5117 Returns 5118 ------- 5119 class_counts : dict 5120 a dictionary where first-level keys are "train", "val" and "test", second-level keys are 5121 class names and values are class counts (in frames) 5122 5123 """ 5124 if load_episode is None: 5125 task, parameters = self._make_task_training( 5126 episode_name="_", parameters_update=parameters_update, throwaway=True 5127 ) 5128 else: 5129 task, parameters, _ = self._make_task_prediction( 5130 "_", 5131 load_episode=load_episode, 5132 parameters_update=parameters_update, 5133 ) 5134 class_counts = task.count_classes(bouts=bouts) 5135 behaviors = task.behaviors_dict() 5136 class_counts = { 5137 kk: {behaviors.get(k, "unknown"): v for k, v in vv.items()} 5138 for kk, vv in class_counts.items() 5139 } 5140 if remove_saved_features: 5141 self._remove_stores(parameters) 5142 return class_counts 5143 5144 def plot_class_distribution( 5145 self, 5146 parameters_update: Dict = None, 5147 frame_cutoff: int = 1, 5148 bout_cutoff: int = 1, 5149 print_full: bool = False, 5150 remove_saved_features: bool = False, 5151 save: str = None, 5152 ) -> None: 5153 """Make a class distribution plot. 
5154 5155 You can either specify the parameters, choose an existing dataset or do neither (in that case a dataset 5156 is created or loaded for the computation with the default parameters). 5157 5158 Parameters 5159 ---------- 5160 parameters_update : dict, optional 5161 a dictionary of parameter updates (only for "data" and "general" categories) 5162 frame_cutoff : int, default 1 5163 the minimum number of frames for a segment to be considered 5164 bout_cutoff : int, default 1 5165 the minimum number of bouts for a class to be considered 5166 print_full : bool, default False 5167 if `True`, the full class distribution is printed 5168 remove_saved_features : bool, default False 5169 if `True`, the dataset that is used for computation is then deleted 5170 5171 """ 5172 task, parameters = self._make_task_training( 5173 episode_name="_", parameters_update=parameters_update, throwaway=True 5174 ) 5175 cutoff = {True: bout_cutoff, False: frame_cutoff} 5176 for bouts in [True, False]: 5177 class_counts = task.count_classes(bouts=bouts) 5178 if print_full: 5179 print("Bouts:" if bouts else "Frames:") 5180 for k, v in class_counts.items(): 5181 if sum(v.values()) != 0: 5182 print(f" {k}:") 5183 values, keys = zip( 5184 *[ 5185 x 5186 for x in sorted(zip(v.values(), v.keys()), reverse=True) 5187 if x[-1] != -100 5188 ] 5189 ) 5190 for kk, vv in zip(keys, values): 5191 print(f" {task.behaviors_dict()[kk]}: {vv}") 5192 class_counts = { 5193 kk: {k: v for k, v in vv.items() if v >= cutoff[bouts]} 5194 for kk, vv in class_counts.items() 5195 } 5196 for key, d in class_counts.items(): 5197 if sum(d.values()) != 0: 5198 values, keys = zip( 5199 *[x for x in sorted(zip(d.values(), d.keys())) if x[-1] != -100] 5200 ) 5201 keys = [task.behaviors_dict()[x] for x in keys] 5202 plt.bar(keys, values) 5203 plt.title(key) 5204 plt.xticks(rotation=45, ha="right") 5205 if bouts: 5206 plt.ylabel("bouts") 5207 else: 5208 plt.ylabel("frames") 5209 plt.tight_layout() 5210 5211 if save is None: 
5212 plt.savefig(save) 5213 plt.close() 5214 else: 5215 plt.show() 5216 if remove_saved_features: 5217 self._remove_stores(parameters) 5218 5219 def _generate_mask( 5220 self, 5221 mask_name: str, 5222 perc_annotated: float = 0.1, 5223 parameters_update: Dict = None, 5224 remove_saved_features: bool = False, 5225 ) -> None: 5226 """Generate a real_lens for active learning simulation. 5227 5228 Parameters 5229 ---------- 5230 mask_name : str 5231 the name of the real_lens 5232 perc_annotated : float, default 0.1 5233 a 5234 5235 """ 5236 print(f"GENERATING {mask_name}") 5237 task, parameters = self._make_task_training( 5238 f"_{mask_name}", parameters_update=parameters_update, throwaway=True 5239 ) 5240 val_intervals, val_ids = task.dataset("val").get_intervals() # 1 5241 unannotated_intervals = task.dataset("train").get_unannotated_intervals() # 2 5242 unannotated_intervals = task.dataset("val").get_unannotated_intervals( 5243 first_intervals=unannotated_intervals 5244 ) 5245 ids = task.dataset("train").get_ids() 5246 mask = {video_id: {} for video_id in ids} 5247 total_all = 0 5248 total_masked = 0 5249 for video_id, clip_ids in ids.items(): 5250 for clip_id in clip_ids: 5251 frames = np.ones(task.dataset("train").get_len(video_id, clip_id)) 5252 if clip_id in val_intervals[video_id]: 5253 for start, end in val_intervals[video_id][clip_id]: 5254 frames[start:end] = 0 5255 if clip_id in unannotated_intervals[video_id]: 5256 for start, end in unannotated_intervals[video_id][clip_id]: 5257 frames[start:end] = 0 5258 annotated = np.where(frames)[0] 5259 total_all += len(annotated) 5260 masked = annotated[-int(len(annotated) * (1 - perc_annotated)) :] 5261 total_masked += len(masked) 5262 mask[video_id][clip_id] = self._get_intervals(masked) 5263 file = { 5264 "masked": mask, 5265 "val_intervals": val_intervals, 5266 "val_ids": val_ids, 5267 "unannotated": unannotated_intervals, 5268 } 5269 self._save_mask(file, mask_name) 5270 if remove_saved_features: 5271 
self._remove_stores(parameters) 5272 print("\n") 5273 # print(f'Unmasked: {sum([(vv == 0).sum() for v in real_lens.values() for vv in v.values()])} frames') 5274 5275 def _get_intervals(self, frame_indices: np.ndarray): 5276 """Get a list of intervals from a list of frame indices. 5277 5278 Example: `[0, 1, 2, 5, 6, 8] -> [[0, 3], [5, 7], [8, 9]]`. 5279 5280 Parameters 5281 ---------- 5282 frame_indices : np.ndarray 5283 a list of frame indices 5284 5285 Returns 5286 ------- 5287 intervals : list 5288 a list of interval boundaries 5289 5290 """ 5291 masked_intervals = [] 5292 if len(frame_indices) > 0: 5293 breaks = np.where(np.diff(frame_indices) != 1)[0] 5294 start = frame_indices[0] 5295 for k in breaks: 5296 masked_intervals.append([start, frame_indices[k] + 1]) 5297 start = frame_indices[k + 1] 5298 masked_intervals.append([start, frame_indices[-1] + 1]) 5299 return masked_intervals 5300 5301 def _update_mask_with_uncertainty( 5302 self, 5303 mask_name: str, 5304 episode_name: Union[str, None], 5305 classes: List, 5306 load_epoch: int = None, 5307 n_frames: int = 10000, 5308 method: str = "least_confidence", 5309 min_length: int = 30, 5310 augment_n: int = 0, 5311 parameters_update: Dict = None, 5312 ): 5313 """Update real_lens with frame-wise uncertainty scores for active learning. 
5314 5315 Parameters 5316 ---------- 5317 mask_name : str 5318 the name of the real_lens 5319 episode_name : str 5320 the name of the episode to load 5321 classes : list 5322 a list of class names or indices; their uncertainty scores will be computed separately and stacked 5323 load_epoch : int, optional 5324 the epoch to load (by default last; if this epoch is not saved the closest checkpoint is chosen) 5325 n_frames : int, default 10000 5326 the number of frames to "annotate" 5327 method : {"least_confidence", "entropy"} 5328 the method used to calculate the scores from the probability predictions (`"least_confidence"`: `1 - p_i` if 5329 `p_i > 0.5` or `p_i` if `p_i < 0.5`; `"entropy"`: `- p_i * log(p_i) - (1 - p_i) * log(1 - p_i)`) 5330 min_length : int 5331 the minimum length (in frames) of the annotated intervals 5332 augment_n : int, default 0 5333 the number of augmentations to average over 5334 parameters_update : dict, optional 5335 the dictionary used to update the parameters from the config 5336 5337 Returns 5338 ------- 5339 score_dicts : dict 5340 a nested dictionary where first level keys are video ids, second level keys are clip ids and values 5341 are score tensors 5342 5343 """ 5344 print(f"UPDATING {mask_name}") 5345 task, parameters, _ = self._make_task_prediction( 5346 prediction_name=mask_name, 5347 load_episode=episode_name, 5348 parameters_update=parameters_update, 5349 load_epoch=load_epoch, 5350 mode="train", 5351 ) 5352 score_tensors = task.generate_uncertainty_score(classes, augment_n, method) 5353 self._update_mask(task, mask_name, score_tensors, n_frames, min_length) 5354 print("\n") 5355 5356 def _update_mask_with_BALD( 5357 self, 5358 mask_name: str, 5359 episode_name: str, 5360 classes: List, 5361 load_epoch: int = None, 5362 augment_n: int = 0, 5363 n_frames: int = 10000, 5364 num_models: int = 10, 5365 kernel_size: int = 11, 5366 min_length: int = 30, 5367 parameters_update: Dict = None, 5368 ): 5369 """Update real_lens with 
    def _suggest_intervals(
        self,
        dataset: BehaviorDataset,
        score_tensors: Dict,
        n_frames: int,
        min_length: int,
    ) -> Dict:
        """Suggest intervals with highest score of total length `n_frames`.

        Intervals are searched with a threshold that starts at the mean
        positive score and is lowered by 10% per round until enough frames are
        taken (or the factor drops below 0.05). Note that `score_tensors` is
        modified in place: already annotated frames are set to -10.

        Parameters
        ----------
        dataset : BehaviorDataset
            the dataset
        score_tensors : dict
            a nested dictionary where first level keys are video ids, second level keys are clip ids
            and values are framewise score tensors (indexed as `[score, frame]`)
        n_frames : int
            the number of frames to "annotate"
        min_length : int
            minimum length of suggested intervals

        Returns
        -------
        active_learning_intervals : Dict
            active learning dictionary with suggested intervals (keys are video ids, values are
            lists of `[start, end, clip_id]`)

        """
        video_intervals, _ = dataset.get_intervals()
        # `taken` marks (with 1) the frames that are annotated or already chosen
        taken = {
            video_id: defaultdict(lambda: {}) for video_id in video_intervals.keys()
        }
        annotated = dataset.get_annotated_intervals()
        for video_id in video_intervals:
            for clip_id in video_intervals[video_id]:
                taken[video_id][clip_id] = torch.zeros(
                    dataset.get_len(video_id, clip_id)
                )
                if video_id in annotated and clip_id in annotated[video_id]:
                    for start, end in annotated[video_id][clip_id]:
                        # annotated frames get a score of -10 (in place)
                        score_tensors[video_id][clip_id][:, start:end] = -10
                        taken[video_id][clip_id][int(start) : int(end)] = 1
        # the target counts the annotated frames plus the requested new ones
        n_frames = (
            sum([(vv == 1).sum() for v in taken.values() for vv in v.values()])
            + n_frames
        )
        factor = 1
        # mean of the positive scores, averaged per clip and then per video
        threshold_start = float(
            torch.mean(
                torch.tensor(
                    [
                        torch.mean(
                            torch.tensor([torch.mean(y[y > 0]) for y in x.values()])
                        )
                        for x in score_tensors.values()
                    ]
                )
            )
        )
        while (
            sum([(vv == 1).sum() for v in taken.values() for vv in v.values()])
            < n_frames
        ):
            threshold = threshold_start * factor
            intervals = []
            interval_scores = []
            # the number of score rows is read from the first clip of the first video
            key1 = list(score_tensors.keys())[0]
            key2 = list(score_tensors[key1].keys())[0]
            num_scores = score_tensors[key1][key2].shape[0]
            for i in range(num_scores):
                v_dict = dataset.find_valleys(
                    predicted=score_tensors,
                    threshold=threshold,
                    min_frames=min_length,
                    main_class=i,
                    low=False,
                )
                for v_id, interval_list in v_dict.items():
                    # assumes the intervals are [start, end, clip_id] lists -- TODO confirm
                    intervals += [x + [v_id] for x in interval_list]
                    interval_scores += [
                        float(torch.mean(score_tensors[v_id][clip_id][i, start:end]))
                        for start, end, clip_id in interval_list
                    ]
            # order the intervals by mean score, highest first
            # NOTE(review): np.array on mixed int/str rows presumably converts
            # everything to strings (hence the int() casts below) -- confirm.
            intervals = np.array(intervals)[np.argsort(interval_scores)[::-1]]
            i = 0
            while sum(
                [(vv == 1).sum() for v in taken.values() for vv in v.values()]
            ) < n_frames and i < len(intervals):
                start, end, clip_id, video_id = intervals[i]
                i += 1
                taken[video_id][clip_id][int(start) : int(end)] = 1
            # not enough frames yet: retry with a lower threshold
            factor *= 0.9
            if factor < 0.05:
                warnings.warn(f"Could not find enough frames!")
                break
        active_learning_intervals = {video_id: [] for video_id in video_intervals}
        for video_id in taken:
            for clip_id in taken[video_id]:
                # previously annotated frames are not part of the suggestion
                if video_id in annotated and clip_id in annotated[video_id]:
                    for start, end in annotated[video_id][clip_id]:
                        taken[video_id][clip_id][int(start) : int(end)] = 0
                if (taken[video_id][clip_id] == 1).sum() == 0:
                    continue
                indices = np.where(taken[video_id][clip_id].numpy())[0]
                boundaries = self._get_intervals(indices)
                active_learning_intervals[video_id] += [
                    [start, end, clip_id] for start, end in boundaries
                ]
        return active_learning_intervals
v_dict.items(): 5490 intervals += [x + [v_id] for x in interval_list] 5491 interval_scores += [ 5492 float(torch.mean(score_tensors[v_id][clip_id][i, start:end])) 5493 for start, end, clip_id in interval_list 5494 ] 5495 intervals = np.array(intervals)[np.argsort(interval_scores)[::-1]] 5496 i = 0 5497 while sum( 5498 [(vv == 1).sum() for v in taken.values() for vv in v.values()] 5499 ) < n_frames and i < len(intervals): 5500 start, end, clip_id, video_id = intervals[i] 5501 i += 1 5502 taken[video_id][clip_id][int(start) : int(end)] = 1 5503 factor *= 0.9 5504 if factor < 0.05: 5505 warnings.warn(f"Could not find enough frames!") 5506 break 5507 active_learning_intervals = {video_id: [] for video_id in video_intervals} 5508 for video_id in taken: 5509 for clip_id in taken[video_id]: 5510 if video_id in annotated and clip_id in annotated[video_id]: 5511 for start, end in annotated[video_id][clip_id]: 5512 taken[video_id][clip_id][int(start) : int(end)] = 0 5513 if (taken[video_id][clip_id] == 1).sum() == 0: 5514 continue 5515 indices = np.where(taken[video_id][clip_id].numpy())[0] 5516 boundaries = self._get_intervals(indices) 5517 active_learning_intervals[video_id] += [ 5518 [start, end, clip_id] for start, end in boundaries 5519 ] 5520 return active_learning_intervals 5521 5522 def _update_mask( 5523 self, 5524 task: TaskDispatcher, 5525 mask_name: str, 5526 score_tensors: Dict, 5527 n_frames: int, 5528 min_length: int, 5529 ) -> None: 5530 """Update the real_lens with intervals with the highest score of total length `n_frames`. 
5531 5532 Parameters 5533 ---------- 5534 task : TaskDispatcher 5535 the task dispatcher object 5536 mask_name : str 5537 the name of the real_lens 5538 score_tensors : dict 5539 a dictionary where keys are clip ids and values are framewise score tensors 5540 n_frames : int 5541 the number of frames to "annotate" 5542 min_length : int 5543 the minimum length of the annotated intervals 5544 5545 """ 5546 mask = self._load_mask(mask_name) 5547 video_intervals, _ = task.dataset("train").get_intervals() 5548 masked = { 5549 video_id: defaultdict(lambda: {}) for video_id in video_intervals.keys() 5550 } 5551 total_masked = 0 5552 total_all = 0 5553 for video_id in video_intervals: 5554 for clip_id in video_intervals[video_id]: 5555 masked[video_id][clip_id] = torch.zeros( 5556 task.dataset("train").get_len(video_id, clip_id) 5557 ) 5558 if ( 5559 video_id in mask["unannotated"] 5560 and clip_id in mask["unannotated"][video_id] 5561 ): 5562 for start, end in mask["unannotated"][video_id][clip_id]: 5563 score_tensors[video_id][clip_id][:, start:end] = -10 5564 masked[video_id][clip_id][int(start) : int(end)] = 1 5565 if ( 5566 video_id in mask["val_intervals"] 5567 and clip_id in mask["val_intervals"][video_id] 5568 ): 5569 for start, end in mask["val_intervals"][video_id][clip_id]: 5570 score_tensors[video_id][clip_id][:, start:end] = -10 5571 masked[video_id][clip_id][int(start) : int(end)] = 1 5572 total_all += torch.sum(masked[video_id][clip_id] == 0) 5573 if video_id in mask["masked"] and clip_id in mask["masked"][video_id]: 5574 # print(f'{real_lens["masked"][video_id][clip_id]=}') 5575 for start, end in mask["masked"][video_id][clip_id]: 5576 masked[video_id][clip_id][int(start) : int(end)] = 1 5577 total_masked += end - start 5578 old_n_frames = sum( 5579 [(vv == 0).sum() for v in masked.values() for vv in v.values()] 5580 ) 5581 n_frames = old_n_frames + n_frames 5582 factor = 1 5583 while ( 5584 sum([(vv == 0).sum() for v in masked.values() for vv in 
v.values()]) 5585 < n_frames 5586 ): 5587 threshold = float( 5588 torch.mean( 5589 torch.tensor( 5590 [ 5591 torch.mean( 5592 torch.tensor([torch.mean(y[y > 0]) for y in x.values()]) 5593 ) 5594 for x in score_tensors.values() 5595 ] 5596 ) 5597 ) 5598 ) 5599 threshold = threshold * factor 5600 intervals = [] 5601 interval_scores = [] 5602 key1 = list(score_tensors.keys())[0] 5603 key2 = list(score_tensors[key1].keys())[0] 5604 num_scores = score_tensors[key1][key2].shape[0] 5605 for i in range(num_scores): 5606 v_dict = task.dataset("train").find_valleys( 5607 predicted=score_tensors, 5608 threshold=threshold, 5609 min_frames=min_length, 5610 main_class=i, 5611 low=False, 5612 ) 5613 for v_id, interval_list in v_dict.items(): 5614 intervals += [x + [v_id] for x in interval_list] 5615 interval_scores += [ 5616 float(torch.mean(score_tensors[v_id][clip_id][i, start:end])) 5617 for start, end, clip_id in interval_list 5618 ] 5619 intervals = np.array(intervals)[np.argsort(interval_scores)[::-1]] 5620 i = 0 5621 while sum( 5622 [(vv == 0).sum() for v in masked.values() for vv in v.values()] 5623 ) < n_frames and i < len(intervals): 5624 start, end, clip_id, video_id = intervals[i] 5625 i += 1 5626 masked[video_id][clip_id][int(start) : int(end)] = 0 5627 factor *= 0.9 5628 if factor < 0.05: 5629 warnings.warn(f"Could not find enough frames!") 5630 break 5631 mask["masked"] = {video_id: {} for video_id in video_intervals} 5632 total_masked_new = 0 5633 for video_id in masked: 5634 for clip_id in masked[video_id]: 5635 if ( 5636 video_id in mask["unannotated"] 5637 and clip_id in mask["unannotated"][video_id] 5638 ): 5639 for start, end in mask["unannotated"][video_id][clip_id]: 5640 masked[video_id][clip_id][int(start) : int(end)] = 0 5641 if ( 5642 video_id in mask["val_intervals"] 5643 and clip_id in mask["val_intervals"][video_id] 5644 ): 5645 for start, end in mask["val_intervals"][video_id][clip_id]: 5646 masked[video_id][clip_id][int(start) : int(end)] = 0 5647 
indices = np.where(masked[video_id][clip_id].numpy())[0] 5648 mask["masked"][video_id][clip_id] = self._get_intervals(indices) 5649 for video_id in mask["masked"]: 5650 for clip_id in mask["masked"][video_id]: 5651 for start, end in mask["masked"][video_id][clip_id]: 5652 total_masked_new += end - start 5653 self._save_mask(mask, mask_name) 5654 with open( 5655 os.path.join( 5656 self.project_path, "results", f"{mask_name}.txt", encoding="utf-8" 5657 ), 5658 "a", 5659 ) as f: 5660 f.write(f"from {total_masked} to {total_masked_new} / {total_all}" + "\n") 5661 print(f"Unmasked from {total_masked} to {total_masked_new} / {total_all}") 5662 5663 def _visualize_results_label( 5664 self, 5665 episode_name: str, 5666 label: str, 5667 load_epoch: int = None, 5668 parameters_update: Dict = None, 5669 add_legend: bool = True, 5670 ground_truth: bool = True, 5671 hide_axes: bool = False, 5672 width: float = 10, 5673 whole_video: bool = False, 5674 transparent: bool = False, 5675 num_plots: int = 1, 5676 smooth_interval: int = 0, 5677 ): 5678 other_path = os.path.join(self.project_path, "results", "other") 5679 if not os.path.exists(other_path): 5680 os.mkdir(other_path) 5681 if parameters_update is None: 5682 parameters_update = {} 5683 if "model" in parameters_update.keys(): 5684 raise ValueError("Cannot change model parameters after training!") 5685 task, parameters, _ = self._make_task_prediction( 5686 "_", 5687 load_episode=episode_name, 5688 parameters_update=parameters_update, 5689 load_epoch=load_epoch, 5690 mode="val", 5691 ) 5692 for i in range(num_plots): 5693 print(i) 5694 task._visualize_results_label( 5695 smooth_interval=smooth_interval, 5696 label=label, 5697 save_path=os.path.join( 5698 other_path, f"{episode_name}_prediction_{i}.jpg" 5699 ), 5700 add_legend=add_legend, 5701 ground_truth=ground_truth, 5702 hide_axes=hide_axes, 5703 whole_video=whole_video, 5704 transparent=transparent, 5705 dataset="val", 5706 width=width, 5707 title=str(i), 5708 ) 5709 5710 
def plot_confusion_matrix(
    self,
    episode_name: str,
    load_epoch: int = None,
    parameters_update: Dict = None,
    metric: str = "recall",
    mode: str = "val",
    remove_saved_features: bool = False,
    save_path: str = None,
    cmap: str = "viridis",
) -> Tuple[ndarray, Iterable]:
    """Make a confusion matrix plot and return the data.

    If the annotation is non-exclusive, only false positive labels are considered.

    Parameters
    ----------
    episode_name : str
        the name of the episode to load
    load_epoch : int, optional
        the index of the epoch to load (by default the last one is loaded)
    parameters_update : dict, optional
        a dictionary of parameter updates (only for "data" and "general" categories)
    metric : {"recall", "precision"}
        for datasets with non-exclusive annotation, if `metric` is `"recall"`, only false positives are taken
        into account, and if `metric` is `"precision"`, only false negatives
    mode : {'val', 'all', 'test', 'train'}
        the subset of the data to make the prediction for
    remove_saved_features : bool, default False
        if `True`, the dataset that is used for computation is then deleted
    save_path : str, optional
        the path to save the figure at (if `None`, the figure is shown instead)
    cmap : str, default "viridis"
        the `matplotlib` colormap passed to `imshow`

    Returns
    -------
    confusion_matrix : np.ndarray
        a confusion matrix of shape `(#classes, #classes)` where `A[i, j] = F_ij/N_i`, `F_ij` is the number of
        frames that have the i-th label in the ground truth and a false positive j-th label in the prediction,
        `N_i` is the number of frames that have the i-th label in the ground truth
    classes : list
        a list of labels

    """
    task, parameters, mode = self._make_task_prediction(
        "_",
        load_episode=episode_name,
        load_epoch=load_epoch,
        parameters_update=parameters_update,
        mode=mode,
    )
    dataset = task.dataset(mode)
    prediction = task.predict(dataset, raw_output=True)
    # The dataset returns the metric it actually used; it is what goes in the title.
    # (The previous code referenced the stale name `type` here and could not run.)
    confusion_matrix, classes, metric = dataset.get_confusion_matrix(
        prediction, metric
    )
    if remove_saved_features:
        self._remove_stores(parameters)

    n_classes = len(classes)
    fig, ax = plt.subplots(figsize=(n_classes, n_classes))
    ax.imshow(confusion_matrix, cmap=cmap)
    # Show all ticks and label them with the respective list entries
    ax.set_xticks(np.arange(n_classes))
    ax.set_xticklabels(classes)
    ax.set_yticks(np.arange(n_classes))
    ax.set_yticklabels(classes)
    # Rotate the tick labels and set their alignment.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
    # Write the (rounded) value into every cell.
    for i in range(n_classes):
        for j in range(n_classes):
            ax.text(
                j,
                i,
                np.round(confusion_matrix[i, j], 2),
                ha="center",
                va="center",
                color="w",
            )
    if metric is not None:
        ax.set_title(f"{metric} {episode_name}")
    else:
        ax.set_title(episode_name)
    fig.tight_layout()
    if save_path is None:
        plt.show()
    else:
        plt.savefig(save_path)
    plt.close()
    return confusion_matrix, classes


def _plot_ethograms_gt_pred(
    self,
    data_gt: dict,
    data_pred: dict,
    labels_gt: list,
    labels_pred: list,
    start: int = 0,
    end: int = -1,
    cmap_pred: str = "binary",
    cmap_gt: str = "binary",
    save: str = None,
    fontsize=22,
    time_mode="frames",
    fps: int = None,
) -> None:
    """Plot ground truth and thresholded prediction ethograms, one subplot per shared behavior.

    Parameters
    ----------
    data_gt : dict
        the loaded ground truth annotation (passed to `binarize_data`)
    data_pred : dict
        the loaded prediction; only the entry under its first key is plotted
    labels_gt : list
        the ground truth behavior labels
    labels_pred : list
        the predicted behavior labels, indexable with `0..len-1`
    start, end : int
        the plotted frame range (a negative `end` means "up to the shorter of the two sequences")
    cmap_pred, cmap_gt : str
        the `matplotlib` colormaps for the prediction and ground truth rows
    save : str, optional
        the path to save the figure at (if `None`, the figure is shown instead)
    fontsize : int, default 22
        the font size of the labels and ticks
    time_mode : {"frames", "seconds"}
        the unit of the x axis
    fps : int, optional
        the frame rate (required when `time_mode` is `"seconds"`)

    """
    best_pred = (
        data_pred[list(data_pred.keys())[0]].numpy() > 0.5
    )  # Threshold the predictions
    data_gt = binarize_data(data_gt, max_frame=end)

    # Crop data to min length
    if end < 0:
        end = min(data_gt.shape[1], best_pred.shape[1])
    data_gt = data_gt[:, :end]
    best_pred = best_pred[:, :end]

    # Reorder behaviors so that ground truth and prediction rows line up.
    # NOTE(review): the labels are rotated by one position relative to the behavior
    # dictionary, the same way `create_annotated_video` does it - confirm the intent.
    labels_pred = [labels_pred[i] for i in range(len(labels_pred))]
    labels_pred = np.roll(labels_pred, 1).tolist()
    check_gt = np.where(np.sum(data_gt, axis=1) > 0)[0]
    check_pred = np.where(np.sum(best_pred, axis=1) > 0)[0]
    ind_gt = []
    ind_pred = []
    for k, gt_beh in enumerate(labels_gt):
        if gt_beh not in labels_pred:
            continue
        j = labels_pred.index(gt_beh)
        # Skip behaviors that occur neither in the ground truth nor in the prediction.
        if k not in check_gt and j not in check_pred:
            continue
        ind_gt.append(labels_gt.index(gt_beh))
        ind_pred.append(j)

    if not ind_gt:
        # `plt.subplots(0, 1)` would raise; there is simply nothing to draw.
        print("No shared behaviors to plot")
        return

    # Create label list
    labels = np.array(labels_gt)[ind_gt]
    assert (labels == np.array(labels_pred)[ind_pred]).all()

    # Create images
    image_pred = best_pred[ind_pred].astype(float)
    image_gt = data_gt[ind_gt]

    f, axs = plt.subplots(
        len(labels), 1, figsize=(5 * len(labels), 15), sharex=True
    )
    # With a single label `plt.subplots` returns a bare Axes rather than an array.
    axs = np.atleast_1d(axs)
    end = image_gt.shape[1] if end < 0 else end
    for i, (ax, label) in enumerate(zip(axs, labels)):
        # Row 0 holds the ground truth and row 1 the prediction; the other row of each
        # image is masked out so that the two colormaps do not overwrite each other.
        im1 = np.array([image_gt[i], np.ones_like(image_gt[i]) * (-1)])
        im1 = np.ma.masked_array(im1, im1 < 0)

        im2 = np.array([np.ones_like(image_pred[i]) * (-1), image_pred[i]])
        im2 = np.ma.masked_array(im2, im2 < 0)

        ax.imshow(im1, aspect="auto", cmap=cmap_gt, interpolation="nearest")
        ax.imshow(im2, aspect="auto", cmap=cmap_pred, interpolation="nearest")

        ax.set_yticks(np.arange(2), ["GT", "Pred"], fontsize=fontsize)
        ax.tick_params(axis="x", labelsize=fontsize)
        ax.set_ylabel(label, fontsize=fontsize)
        if time_mode == "frames":
            ax.set_xlabel("Num Frames", fontsize=fontsize)
        elif time_mode == "seconds":
            assert fps is not None, "Please provide fps"
            ax.set_xlabel("Time (s)", fontsize=fontsize)
            ax.set_xticks(
                np.linspace(0, end, 10),
                np.linspace(0, end / fps, 10).astype(np.int32),
            )

        ax.set_xlim(start, end)

    if save is None:
        plt.show()
    else:
        plt.savefig(save)
    plt.close()


def plot_ethograms(
    self,
    episode_name: str,
    prediction_name: str,
    start: int = 0,
    end: int = -1,
    save_path: str = None,
    cmap_pred: str = "binary",
    cmap_gt: str = "binary",
    fontsize: int = 22,
    time_mode: str = "frames",
    fps: int = None,
):
    """Plot ethograms from start to end time (in frames) for ground truth and prediction.

    Parameters
    ----------
    episode_name : str
        the name of the episode (used to look up the behavior dictionary)
    prediction_name : str
        the name of the prediction folder under `results/predictions`
    start, end : int
        the plotted frame range (a negative `end` means "until the end")
    save_path : str, optional
        the folder to save the figures in (if `None`, the figures are shown instead)
    cmap_pred, cmap_gt : str
        the `matplotlib` colormaps for the prediction and ground truth rows
    fontsize : int, default 22
        the font size of the labels and ticks
    time_mode : {"frames", "seconds"}
        the unit of the x axis
    fps : int, optional
        the frame rate (required when `time_mode` is `"seconds"`)

    """
    params = self._read_parameters(catch_blanks=False)
    parameters = self._get_data_pars(
        params,
    )
    if save_path is not None:
        os.makedirs(save_path, exist_ok=True)
    prediction_dir = os.path.join(
        self.project_path, "results", "predictions", prediction_name
    )
    # The behavior dictionary does not depend on the file, so it is read once.
    behaviors = self.get_behavior_dictionary(episode_name)
    prediction_suffix = "_".join(["_" + prediction_name, "prediction.pickle"])
    for file_name in os.listdir(prediction_dir):
        prediction_file = os.path.join(prediction_dir, file_name)
        predictions = load_pickle(prediction_file)
        # The ground truth file shares the prediction file name up to the suffix.
        gt_filename = file_name.replace(
            prediction_suffix,
            parameters["annotation_suffix"],
        )
        gt_file = os.path.join(self.data_path, gt_filename)
        if not os.path.exists(gt_file):
            print("GT file not found")
            continue
        gt_data = load_pickle(gt_file)

        # Without a save folder the figure is displayed instead of written to disk.
        save = None
        if save_path is not None:
            save = os.path.join(
                save_path,
                os.path.splitext(file_name)[0] + "_gt_pred",
            )
        self._plot_ethograms_gt_pred(
            gt_data,
            predictions,
            gt_data[1],
            behaviors,
            start=start,
            end=end,
            save=save,
            cmap_pred=cmap_pred,
            cmap_gt=cmap_gt,
            fontsize=fontsize,
            time_mode=time_mode,
            fps=fps,
        )


def _create_side_panel(self, height, width, labels_pred, preds, labels_gt, gt=None):
    """Create a side panel for video annotation display.

    Parameters
    ----------
    height : int
        the height of the panel
    width : int
        the width of the video frame (the panel is a quarter of it wide)
    labels_pred : list
        the list of predicted behavior labels
    preds : array-like
        the prediction values for each behavior
    labels_gt : list
        the list of ground truth behavior labels
    gt : array-like, optional
        the ground truth values for each behavior

    Returns
    -------
    side_panel : np.ndarray
        the created side panel as an image array

    """
    side_panel = np.ones((height, int(width / 4), 3), dtype=np.uint8) * 255
    inactive_color = (0, 0, 0)
    active_color = (0, 255, 0)

    active_pred = np.where(preds)[0]
    for i, label in enumerate(labels_pred):
        cv2.putText(
            side_panel,
            label,
            (10, 50 + 50 * i),
            cv2.FONT_HERSHEY_SIMPLEX,
            1,
            active_color if i in active_pred else inactive_color,
            2,
            cv2.LINE_AA,
        )
    if gt is not None:
        active_gt = np.where(gt)[0]
        # The ground truth labels are written below the prediction labels.
        gt_offset = 80 * len(labels_pred)
        for i, label in enumerate(labels_gt):
            cv2.putText(
                side_panel,
                label,
                (10, 50 + 50 * i + gt_offset),
                cv2.FONT_HERSHEY_SIMPLEX,
                1,
                active_color if i in active_gt else inactive_color,
                2,
                cv2.LINE_AA,
            )
    return side_panel


def create_annotated_video(
    self,
    prediction_file_paths: list,
    video_file_paths: list,
    episode_name: str,  # To get the list of behaviors
    ground_truth_file_paths: list = None,
    pred_thresh: float = 0.5,
    start: int = 0,
    end: int = -1,
):
    """Create videos with the predictions (and optionally ground truth) shown in a side panel.

    The output is written next to each input video as `<video name>_annotated.mp4`,
    downscaled to a quarter of the size of the frame with the side panel attached.

    Parameters
    ----------
    prediction_file_paths : list
        the paths to the prediction pickle files (one per video)
    video_file_paths : list
        the paths to the video files
    episode_name : str
        the name of the episode (used to look up the behavior dictionary)
    ground_truth_file_paths : list, optional
        the paths to the ground truth pickle files (one per video)
    pred_thresh : float, default 0.5
        the threshold applied to the predictions
    start : int, default 0
        the first frame to include
    end : int, default -1
        the frame to stop at (a negative value means "until the end of the video")

    """
    scale = 0.25
    behaviors = self.get_behavior_dictionary(episode_name)
    # NOTE(review): the labels are rotated by one position relative to the behavior
    # dictionary, the same way `_plot_ethograms_gt_pred` does it - confirm the intent.
    labels_pred = [behaviors[i] for i in range(len(behaviors))]
    labels_pred = np.roll(labels_pred, 1).tolist()

    for k, (pred_path, vid_path) in enumerate(
        zip(prediction_file_paths, video_file_paths)
    ):
        print("Generating video for :", os.path.basename(vid_path))
        predictions = load_pickle(pred_path)
        best_pred = predictions[list(predictions.keys())[0]].numpy() > pred_thresh

        gt_data = None
        labels_gt = []
        if ground_truth_file_paths is not None:
            gt_data = load_pickle(ground_truth_file_paths[k])
            labels_gt = gt_data[1]
            gt_data = binarize_data(gt_data, max_frame=best_pred.shape[1])

        cap = cv2.VideoCapture(vid_path)
        out = None
        progress = None
        try:
            cap.set(cv2.CAP_PROP_POS_FRAMES, start)
            # Resolved per video: a negative `end` must not be replaced by the frame
            # count of the first video for all the following ones.
            video_end = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) if end < 0 else end
            video_end = min(video_end, best_pred.shape[1])
            fps = cap.get(cv2.CAP_PROP_FPS)
            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            # The writer only accepts frames of exactly the size it was opened with.
            out_size = (
                max(1, int(round((width + int(width / 4)) * scale))),
                max(1, int(round(height * scale))),
            )
            out = cv2.VideoWriter(
                os.path.join(
                    os.path.dirname(vid_path),
                    os.path.splitext(os.path.basename(vid_path))[0]
                    + "_annotated.mp4",
                ),
                cv2.VideoWriter_fourcc(*"mp4v"),
                fps,
                out_size,
            )
            progress = tqdm(total=max(video_end - start, 0))
            # The capture was positioned at `start`, so the labels are read from there too.
            frame_idx = start
            while cap.isOpened() and frame_idx < video_end:
                ret, frame = cap.read()
                if not ret:
                    break

                gt_frame = None
                if gt_data is not None and frame_idx < gt_data.shape[1]:
                    gt_frame = gt_data[:, frame_idx]
                side_panel = self._create_side_panel(
                    height,
                    width,
                    labels_pred,
                    best_pred[:, frame_idx],
                    labels_gt,
                    gt_frame,
                )
                frame = np.concatenate((frame, side_panel), axis=1)
                frame = cv2.resize(frame, out_size)
                out.write(frame)
                frame_idx += 1
                progress.update(1)
        finally:
            cap.release()
            if out is not None:
                out.release()
            if progress is not None:
                progress.close()
    cv2.destroyAllWindows()


def plot_predictions(
    self,
    episode_name: str,
    load_epoch: int = None,
    parameters_update: Dict = None,
    add_legend: bool = True,
    ground_truth: bool = True,
    colormap: str = "dlc2action",
    hide_axes: bool = False,
    min_classes: int = 1,
    width: float = 10,
    whole_video: bool = False,
    transparent: bool = False,
    drop_classes: Set = None,
    search_classes: Set = None,
    num_plots: int = 1,
    remove_saved_features: bool = False,
    smooth_interval_prediction: int = 0,
    data_path: str = None,
    file_paths: Set = None,
    mode: str = "val",
    font_size: float = None,
    window_size: int = 400,
) -> None:
    """Visualize random predictions.

    The plot is saved at `results/plots/<episode_name>_prediction.svg` in the project folder.

    Parameters
    ----------
    episode_name : str
        the name of the episode to load
    load_epoch : int, optional
        the epoch to load (by default last)
    parameters_update : dict, optional
        parameter update dictionary
    add_legend : bool, default True
        if True, legend will be added to the plot
    ground_truth : bool, default True
        if True, ground truth will be added to the plot
    colormap : str, default 'dlc2action'
        the colormap to use
    hide_axes : bool, default False
        if `True`, the axes will be hidden on the plot
    min_classes : int, default 1
        the minimum number of classes in a displayed interval
    width : float, default 10
        the width of the plot
    whole_video : bool, default False
        if `True`, whole videos are plotted instead of segments
    transparent : bool, default False
        if `True`, the background on the plot is transparent
    drop_classes : set, optional
        a set of class names to not be displayed
    search_classes : set, optional
        if given, only intervals where at least one of the classes is in ground truth will be shown
    num_plots : int, default 1
        the number of plots to make
    remove_saved_features : bool, default False
        if `True`, the dataset will be deleted after computation
    smooth_interval_prediction : int, default 0
        if >0, predictions shorter than this number of frames are removed (filled with prediction for the previous frame)
    data_path : str, optional
        the data path to run the prediction for
    file_paths : set, optional
        a set of string file paths (data with all prefixes + feature files, in any order) to run the prediction
        for
    mode : {'all', 'test', 'val', 'train'}
        the subset of the data to make the prediction for (forced to 'all' if data_path is not None)
    font_size : float, optional
        the font size, passed on to the visualization
    window_size : int, default 400
        the window size, passed on to the visualization

    """
    plot_path = os.path.join(self.project_path, "results", "plots")
    task, parameters, mode = self._make_task_prediction(
        "_",
        load_episode=episode_name,
        parameters_update=parameters_update,
        load_epoch=load_epoch,
        data_path=data_path,
        file_paths=file_paths,
        mode=mode,
    )
    os.makedirs(plot_path, exist_ok=True)
    task.visualize_results(
        save_path=os.path.join(plot_path, f"{episode_name}_prediction.svg"),
        add_legend=add_legend,
        ground_truth=ground_truth,
        colormap=colormap,
        hide_axes=hide_axes,
        min_classes=min_classes,
        whole_video=whole_video,
        transparent=transparent,
        dataset=mode,
        drop_classes=drop_classes,
        search_classes=search_classes,
        width=width,
        smooth_interval_prediction=smooth_interval_prediction,
        font_size=font_size,
        num_plots=num_plots,
        window_size=window_size,
    )
    if remove_saved_features:
        self._remove_stores(parameters)


def create_video_from_labels(
    self,
    video_dir_path: str,
    mode="ground_truth",
    prediction_name: str = None,
    save_path: str = None,
):
    """Cut the videos in a folder into one clip per labeled interval.

    Parameters
    ----------
    video_dir_path : str
        the folder with the video files
    mode : {"ground_truth", "prediction"}
        where the intervals come from: the annotation files of the project or a saved prediction
    prediction_name : str, optional
        the name of the prediction (required when `mode` is `"prediction"`)
    save_path : str, optional
        the output folder (by default `results/annotated_videos_from_<mode>` in the project folder)

    """
    if mode not in ("ground_truth", "prediction"):
        raise ValueError(
            f"The {mode} mode is not recognized; please choose from ['ground_truth', 'prediction']"
        )
    if save_path is None:
        save_path = os.path.join(
            self.project_path, "results", f"annotated_videos_from_{mode}"
        )
    os.makedirs(save_path, exist_ok=True)

    params = self._read_parameters(catch_blanks=False)

    if mode == "ground_truth":
        source_dir = self.annotation_path
        annotation_suffix = params["data"]["annotation_suffix"]
    else:
        assert (
            prediction_name is not None
        ), "Please provide a prediction name with mode 'prediction'"
        source_dir = os.path.join(
            self.project_path, "results", "predictions", prediction_name
        )
        annotation_suffix = f"_{prediction_name}_prediction.pickle"

    # Pair every video with the label file that shares its name without the extension.
    video_annotation_pairs = []
    for f in os.listdir(video_dir_path):
        annotation_file = os.path.join(
            source_dir, os.path.splitext(f)[0] + annotation_suffix
        )
        if os.path.exists(annotation_file):
            video_annotation_pairs.append(
                (os.path.join(video_dir_path, f), annotation_file)
            )

    for video_file, annotation_file in tqdm(video_annotation_pairs):
        if not os.path.exists(video_file):
            print(f"Video file {video_file} does not exist, skipping.")
            continue
        if not os.path.exists(annotation_file):
            print(f"Annotation file {annotation_file} does not exist, skipping.")
            continue

        if annotation_file.endswith(".pickle"):
            annotations = load_pickle(annotation_file)
        elif annotation_file.endswith(".csv"):
            annotations = pd.read_csv(annotation_file)
        else:
            print(f"Annotation file {annotation_file} has an unknown format, skipping.")
            continue

        if mode == "ground_truth":
            # The ground truth stores {individual: {behavior: intervals}} at index 3.
            intervals = annotations[3]
        else:
            intervals = self._prediction_to_intervals(
                annotations, exclusive=params["general"]["exclusive"]
            )

        for individual in intervals:
            for behavior in intervals[individual]:
                self._extract_videos(
                    video_file,
                    intervals[individual][behavior],
                    behavior,
                    individual,
                    save_path,
                    resolution=(640, 480),
                    fps=30,
                )


def _prediction_to_intervals(
    self, predictions: Dict, exclusive: bool, threshold: float = 0.5
) -> Dict:
    """Convert a loaded prediction into `{individual: {behavior: [(start, end, confusion)]}}`.

    NOTE(review): assumes every non-meta entry of the prediction is a
    `(classes, frames)` array, as `create_annotated_video` indexes it, and that the
    rows follow the order of `predictions["classes"]` - confirm against the writer.

    Parameters
    ----------
    predictions : dict
        the loaded prediction dictionary
    exclusive : bool
        if `True`, every frame gets the single highest scoring class; otherwise each
        class is thresholded independently
    threshold : float, default 0.5
        the decision threshold for non-exclusive predictions

    Returns
    -------
    intervals : dict
        a nested dictionary of `(start, end, confusion)` tuples with `end` exclusive;
        the confusion is always 0 so that `_extract_videos` keeps every interval

    """
    behaviors = list(predictions["classes"].values())
    meta_keys = {"classes", "min_frame", "max_frame"}
    intervals = {}
    for individual, scores in predictions.items():
        if individual in meta_keys:
            continue
        scores = np.asarray(scores)
        winners = np.argmax(scores, axis=0) if exclusive else None
        intervals[individual] = {}
        for k, behavior in enumerate(behaviors):
            if k >= scores.shape[0]:
                break
            mask = winners == k if exclusive else scores[k] > threshold
            intervals[individual][behavior] = [
                (frames[0], frames[-1] + 1, 0.0)
                for frames in self._bin_array_to_sequences(mask, target_value=True)
            ]
    return intervals


def _bin_array_to_sequences(
    self, annot_data: np.ndarray, target_value: int
) -> List[List[int]]:
    """Find the runs of consecutive frames where a 1D array equals the target value.

    Parameters
    ----------
    annot_data : np.ndarray
        a one-dimensional array of per-frame values
    target_value : int
        the value to look for

    Returns
    -------
    subsequences : list
        a list of runs, each a list of consecutive frame indices

    """
    is_target = annot_data == target_value
    # Padding with False on both sides makes every run produce exactly one rising and
    # one falling edge, so the edge positions pair up as (start, end) with end exclusive.
    changes = np.diff(np.concatenate(([False], is_target, [False])))
    indices = np.where(changes)[0].reshape(-1, 2)
    subsequences = [list(range(start, end)) for start, end in indices]
    return subsequences


def _extract_videos(
    self,
    video_file: str,
    intervals: np.ndarray,
    behavior: str,
    individual: str,
    video_dir: str,
    resolution: Tuple[int, int] = (640, 480),
    fps: int = 30,
) -> None:
    """Write one clip per interval into `<video_dir>/<individual>/<behavior>/`.

    Parameters
    ----------
    video_file : str
        the path to the source video
    intervals : array-like
        an iterable of `(start, end, confusion)` triples; intervals with a confusion
        above 0.5 are skipped
    behavior : str
        the behavior name (used in the folder and file names)
    individual : str
        the individual name (used in the folder and file names)
    video_dir : str
        the root output folder
    resolution : tuple, default (640, 480)
        the `(width, height)` of the written clips
    fps : int, default 30
        the frame rate of the written clips

    """
    cap = cv2.VideoCapture(video_file)
    print("Extracting frames from", video_file)

    clip_dir = os.path.join(video_dir, individual, behavior)
    # The writer does not create missing folders and would silently write nothing.
    os.makedirs(clip_dir, exist_ok=True)
    video_stem = os.path.splitext(os.path.basename(video_file))[0]
    frame_size = (resolution[0], resolution[1])
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")  # Codec, e.g., 'XVID', 'MJPG'

    try:
        for start, end, confusion in tqdm(intervals):
            assert start < end, "Start frame should be less than end frame"
            if confusion > 0.5:
                continue
            start, end = int(start), int(end)
            cap.set(cv2.CAP_PROP_POS_FRAMES, start)
            output_file = os.path.join(
                clip_dir,
                video_stem
                + f"vid_{individual}_{behavior}_{start:05d}_{end:05d}.mp4",
            )
            out = cv2.VideoWriter(output_file, fourcc, fps, frame_size)
            frames_written = 0
            try:
                while cap.isOpened() and start + frames_written < end:
                    ret, frame = cap.read()
                    if not ret:
                        break
                    # Frames have to match the size the writer was opened with.
                    out.write(cv2.resize(frame, frame_size))
                    frames_written += 1
            finally:
                # Released before a possible removal: the file is still open until then.
                out.release()
            # Clips of at most two frames are not worth keeping.
            if frames_written <= 2 and os.path.exists(output_file):
                os.remove(output_file)
    finally:
        cap.release()


def create_metadata_backup(self) -> None:
    """Create a copy of the meta files in the `meta/backup` folder of the project.

    An existing backup is replaced; subfolders of `meta` are not copied.
    """
    meta_path = os.path.join(self.project_path, "meta")
    meta_copy_path = os.path.join(meta_path, "backup")
    if os.path.exists(meta_copy_path):
        shutil.rmtree(meta_copy_path)
    os.mkdir(meta_copy_path)
    for file in os.listdir(meta_path):
        source = os.path.join(meta_path, file)
        # Skips the backup folder itself along with any other directory.
        if file == "backup" or os.path.isdir(source):
            continue
        shutil.copy(source, os.path.join(meta_copy_path, file))


def load_metadata_backup(self) -> None:
    """Load from previously created meta data backup (in case of corruption)."""
    meta_path = os.path.join(self.project_path, "meta")
    meta_copy_path = os.path.join(meta_path, "backup")
    for file in os.listdir(meta_copy_path):
        shutil.copy(
            os.path.join(meta_copy_path, file), os.path.join(meta_path, file)
        )


def get_behavior_dictionary(self, episode_name: str) -> Dict:
    """Get the behavior dictionary for an episode.

    Parameters
    ----------
    episode_name : str
        the name of the episode

    Returns
    -------
    behaviors_dictionary : dict
        a dictionary where keys are label indices and values are label names

    """
    return self._episode(episode_name).get_behaviors_dict()


def import_episodes(
    self,
    episodes_directory: str,
    name_map: Dict = None,
    repeat_policy: str = "error",
) -> None:
    """Import episodes exported with `Project.export_episodes`.

    Parameters
    ----------
    episodes_directory : str
        the path to the exported episodes directory
    name_map : dict, optional
        a name change dictionary for the episodes: keys are old names, values are new names
    repeat_policy : {'error', 'skip', 'force'}, default 'error'
        the policy for repeated episode names: 'error' raises an error, 'skip' skips duplicates,
        'force' overwrites existing episodes

    """
    if name_map is None:
        name_map = {}
    episodes = pd.read_pickle(os.path.join(episodes_directory, "episodes.pickle"))
    to_remove = []
    import_string = "Imported episodes: "
    for original_name in episodes.index:
        if original_name in name_map:
            episode_name = name_map[original_name]
            import_string += f"{original_name} ({episode_name}), "
        else:
            episode_name = original_name
            import_string += f"{episode_name}, "
        try:
            self._check_episode_validity(episode_name, allow_doublecolon=True)
        except ValueError as e:
            # NOTE(review): any other validity error is ignored here, as before.
            if str(e).endswith("is already taken!"):
                if repeat_policy == "skip":
                    # The table is indexed by the exported names, not the renamed ones.
                    to_remove.append(original_name)
                elif repeat_policy == "force":
                    self.remove_episode(episode_name)
                elif repeat_policy == "error":
                    raise ValueError(
                        f"The {episode_name} episode name is already taken; please use the name_map parameter to rename it"
                    )
                else:
                    raise ValueError(
                        f"The {repeat_policy} repeat policy is not recognized; please choose from ['skip', 'force' and 'error']"
                    )
    episodes = episodes.drop(index=to_remove)
    self._episodes().update(
        episodes,
        name_map=name_map,
        force=(repeat_policy == "force"),
        data_path=self.data_path,
        annotation_path=self.annotation_path,
    )
    for episode_name in episodes.index:
        new_episode_name = name_map.get(episode_name, episode_name)
        model_dir = os.path.join(
            self.project_path, "results", "model", new_episode_name
        )
        old_model_dir = os.path.join(episodes_directory, "model", episode_name)
        if os.path.exists(model_dir):
            shutil.rmtree(model_dir)
        os.makedirs(model_dir)
        for file in os.listdir(old_model_dir):
            shutil.copyfile(
                os.path.join(old_model_dir, file), os.path.join(model_dir, file)
            )
        log_file = os.path.join(
            self.project_path, "results", "logs", f"{new_episode_name}.txt"
        )
        old_log_file = os.path.join(
            episodes_directory, "logs", f"{episode_name}.txt"
        )
        shutil.copyfile(old_log_file, log_file)
    print(import_string)
    print("\n")


def export_episodes(
    self, episode_names: List, output_directory: str, name: str = None
) -> None:
    """Save selected episodes as a file that can be imported into another project with `Project.import_episodes`.

    Parameters
    ----------
    episode_names : list
        a list of string episode names
    output_directory : str
        the path to the directory where the episodes will be saved
    name : str, optional
        the name of the episodes directory (by default `exported_episodes`)

    """
    if name is None:
        name = "exported_episodes"

    def taken(candidate: str) -> bool:
        """Check whether a folder or a zip archive with this name already exists."""
        return os.path.exists(
            os.path.join(output_directory, candidate + ".zip")
        ) or os.path.exists(os.path.join(output_directory, candidate))

    if taken(name):
        # Append the first free numeric suffix instead of overwriting an export.
        i = 1
        while taken(name + f"_{i}"):
            i += 1
        name = name + f"_{i}"
    dest_dir = os.path.join(output_directory, name)
    # `makedirs` also creates the output directory when it does not exist yet.
    os.makedirs(os.path.join(dest_dir, "model"))
    os.makedirs(os.path.join(dest_dir, "logs"))
    runs = []
    for episode in episode_names:
        runs += self._episodes().get_runs(episode)
    for run in runs:
        shutil.copytree(
            os.path.join(self.project_path, "results", "model", run),
            os.path.join(dest_dir, "model", run),
        )
        shutil.copyfile(
            os.path.join(self.project_path, "results", "logs", f"{run}.txt"),
            os.path.join(dest_dir, "logs", f"{run}.txt"),
        )
    data = self._episodes().get_subset(runs)
    data.to_pickle(os.path.join(dest_dir, "episodes.pickle"))


def get_results_table(
    self,
    episode_names: List,
    metrics: List = None,
    mode: str = "mean",  # Choose between ["mean", "statistics", "detail"]
    print_results: bool = True,
    classes: List = None,
):
    """Generate a `pandas` dataframe with a summary of episode results.

    Parameters
    ----------
    episode_names : list
        a list of names of episodes to include
    metrics : list, optional
        a list of metric names to include
    mode : {"mean", "statistics", "detail"}
        the mode of the results table: the mean over the runs, the mean and standard deviation,
        or one column per run
    print_results : bool, optional
        if True, the results will be printed to the console, by default True
    classes : list, optional
        a list of names of classes to include (by default all are included)

    Returns
    -------
    results : pd.DataFrame
        a table with the results

    """
    run_names = []
    for episode in episode_names:
        run_names += self._episodes().get_runs(episode)
    episodes = self.list_episodes(run_names, print_results=False)
    metric_columns = [x for x in episodes.columns if x[0] == "results"]
    results_df = pd.DataFrame()
    if metrics is not None:
        metric_columns = [
            x for x in metric_columns if x[1].split("_")[0] in metrics
        ]
    for episode in episode_names:
        results = []
        metric_set = set()
        for run in self._episodes().get_runs(episode):
            beh_dict = self.get_behavior_dictionary(run)
            res_dict = defaultdict(lambda: {})
            for column in metric_columns:
                value = episodes.loc[run, column]
                if np.isnan(value):
                    continue
                # A numeric suffix marks a per-class value; the rest are averages.
                split = column[1].split("_")
                if split[-1].isnumeric():
                    beh = beh_dict[int(split[-1])]
                    metric_name = "_".join(split[:-1])
                else:
                    beh = "average"
                    metric_name = column[1]
                res_dict[beh][metric_name] = value
                metric_set.add(metric_name)
            if "average" not in res_dict:
                res_dict["average"] = {}
            # Fill in the averages that were not stored from the per-class values.
            for metric in metric_set:
                if metric in res_dict["average"]:
                    continue
                arr = [
                    res_dict[beh][metric]
                    for beh in res_dict
                    if metric in res_dict[beh]
                ]
                if arr:
                    res_dict["average"][metric] = np.mean(arr)
            results.append(res_dict)
        if not results:
            # An episode without runs has nothing to report.
            continue
        episode_results = {}
        for metric in metric_set:
            for beh in results[0].keys():
                if classes is not None and beh not in classes:
                    continue
                arr = [
                    res_dict[beh][metric]
                    for res_dict in results
                    if metric in res_dict[beh]
                ]
                if len(arr) == 0:
                    continue
                if mode == "statistics":
                    episode_results[(beh, f"{episode} {metric} mean")] = np.mean(
                        arr
                    )
                    episode_results[(beh, f"{episode} {metric} std")] = np.std(arr)
                elif mode == "mean":
                    episode_results[(beh, f"{episode} {metric}")] = np.mean(arr)
                elif mode == "detail":
                    for i, val in enumerate(arr):
                        episode_results[(beh, f"{episode}::{i} {metric}")] = val
        for key, value in episode_results.items():
            results_df.loc[key[0], key[1]] = value
    if print_results:
        print("RESULTS:")
        print(results_df)
        print("\n")
    return results_df


def episode_exists(self, episode_name: str) -> bool:
    """Check if an episode already exists.

    Parameters
    ----------
    episode_name : str
        the episode name

    Returns
    -------
    exists : bool
        `True` if the episode exists

    """
    # NOTE(review): this returns `check_name_validity` as is - confirm that it reports
    # existence and not availability of the name.
    return self._episodes().check_name_validity(episode_name)


def search_exists(self, search_name: str) -> bool:
    """Check if a search already exists.

    Parameters
    ----------
    search_name : str
        the search name

    Returns
    -------
    exists : bool
        `True` if the search exists

    """
    # NOTE(review): this returns `check_name_validity` as is - confirm that it reports
    # existence and not availability of the name.
    return self._searches().check_name_validity(search_name)


def prediction_exists(self, prediction_name: str) -> bool:
    """Check if a prediction already exists.

    Parameters
    ----------
    prediction_name : str
        the prediction name

    Returns
    -------
    exists : bool
        `True` if the prediction exists

    """
    # NOTE(review): this returns `check_name_validity` as is - confirm that it reports
    # existence and not availability of the name.
    return self._predictions().check_name_validity(prediction_name)


@staticmethod
def project_name_available(projects_path: str, project_name: str):
    """Check if a project name is available.

    Parameters
    ----------
    projects_path : str
        the path to the projects directory (if `None`, `~/DLC2Action` is used)
    project_name : str
        the name of the project to check

    Returns
    -------
    available : bool
        `True` if the project name is available

    """
    if projects_path is None:
        projects_path = os.path.join(str(Path.home()), "DLC2Action")
    return not os.path.exists(os.path.join(projects_path, project_name))


def _update_episode_metrics(self, episode_name: str, metrics: Dict):
    """Update meta data with evaluation results.

    Parameters
    ----------
    episode_name : str
        the name of the episode
    metrics : dict
        the metrics dictionary to update with

    """
    self._episodes().update_episode_metrics(episode_name, metrics)


def rename_episode(self, episode_name: str, new_episode_name: str):
    """Rename an episode.

    The model folder and the log file are moved first, then the meta data is updated.

    Parameters
    ----------
    episode_name : str
        the current episode name
    new_episode_name : str
        the new episode name

    """
    results_path = os.path.join(self.project_path, "results")
    shutil.move(
        os.path.join(results_path, "model", episode_name),
        os.path.join(results_path, "model", new_episode_name),
    )
    shutil.move(
        os.path.join(results_path, "logs", f"{episode_name}.txt"),
        os.path.join(results_path, "logs", f"{new_episode_name}.txt"),
    )
    self._episodes().rename_episode(episode_name, new_episode_name)
A class to create and maintain the project files and to keep track of experiments.
def __init__(
    self,
    name: str,
    data_type: str = None,
    annotation_type: str = "none",
    projects_path: str = None,
    data_path: Union[str, List] = None,
    annotation_path: Union[str, List] = None,
    copy: bool = False,
) -> None:
    """Initialize the class.

    A new project is created if there is no folder called `name` in the projects
    folder; otherwise the existing project is loaded and the arguments are checked
    against the stored settings.

    Parameters
    ----------
    name : str
        name of the project
    data_type : str, optional
        data type (run Project.data_types() to see available options; has to be provided if the project is being
        created)
    annotation_type : str, default 'none'
        annotation type (run Project.annotation_types() to see available options)
    projects_path : str, optional
        path to the projects folder (is filled with ~/DLC2Action by default)
    data_path : str, optional
        path to the folder containing input files for the project (has to be provided if the project is being
        created)
    annotation_path : str, optional
        path to the folder containing annotation files for the project
    copy : bool, default False
        if True, the files from annotation_path and data_path will be copied to the projects folder;
        otherwise they will be moved

    Raises
    ------
    ValueError
        if `data_type` is missing for a new project, or if `data_type`, `annotation_type`
        or `annotation_path` contradict the settings of an existing project

    """
    if projects_path is None:
        projects_path = os.path.join(str(Path.home()), "DLC2Action")
    # `makedirs` also handles nested paths and a folder created in the meantime.
    os.makedirs(projects_path, exist_ok=True)
    self.project_path = os.path.join(projects_path, name)
    self.name = name
    self.data_type = data_type
    self.annotation_type = annotation_type
    self.data_path = data_path
    self.annotation_path = annotation_path
    if not os.path.exists(self.project_path):
        if data_type is None:
            raise ValueError(
                "The data_type parameter is necessary when creating a new project!"
            )
        self._initialize_project(
            data_type, annotation_type, data_path, annotation_path, copy
        )
    else:
        self.annotation_type, self.data_type = self._read_types()
        if data_type != self.data_type and data_type is not None:
            raise ValueError(
                f"The project has already been initialized with data_type={self.data_type}!"
            )
        if annotation_type != self.annotation_type and annotation_type != "none":
            raise ValueError(
                f"The project has already been initialized with annotation_type={self.annotation_type}!"
            )
        self.annotation_path, stored_data_path = self._read_paths()
        # An explicitly passed data path takes precedence over the stored one and is
        # deliberately not checked against it.
        if self.data_path is None:
            self.data_path = stored_data_path
        if annotation_path != self.annotation_path and annotation_path is not None:
            raise ValueError(
                f"The project has already been initialized with annotation_path={self.annotation_path}!"
            )
    self._update_configs()
Initialize the class.
Parameters
name : str name of the project data_type : str, optional data type (run Project.data_types() to see available options; has to be provided if the project is being created) annotation_type : str, default 'none' annotation type (run Project.annotation_types() to see available options) projects_path : str, optional path to the projects folder (is filled with ~/DLC2Action by default) data_path : str, optional path to the folder containing input files for the project (has to be provided if the project is being created) annotation_path : str, optional path to the folder containing annotation files for the project copy : bool, default False if True, the files from annotation_path and data_path will be copied to the projects folder; otherwise they will be moved
def get_decision_thresholds(
    self,
    episode_names: List,
    metric_name: str = "f1",
    parameters_update: Dict = None,
    load_epochs: List = None,
    remove_saved_features: bool = False,
) -> Tuple[List, List, TaskDispatcher]:
    """Compute optimal decision thresholds or load them if they have been computed before.

    Parameters
    ----------
    episode_names : List
        a list of episode names
    metric_name : {"f1", "segmental_f1", "semisegmental_f1", "f_beta", "segmental_f_beta"}
        the metric to optimize
    parameters_update : dict, optional
        the parameter update dictionary
    load_epochs : list, optional
        a list of epochs to load, one per episode (by default last are loaded)
    remove_saved_features : bool, default False
        if `True`, the dataset will be deleted after the computation
        (NOTE(review): not used by the current implementation)

    Returns
    -------
    thresholds : list
        a list of float decision threshold values
    classes : list
        the label names corresponding to the values
    task : TaskDispatcher | None
        the task used in computation (always `None` in the current implementation)

    """
    if load_epochs is None:
        # The default used to crash on `load_epochs[0]`; use one "no explicit
        # epoch" placeholder per episode instead.
        load_epochs = [None for _ in episode_names]
    first_episode = episode_names[0]
    parameters = self._make_parameters(
        "_",
        first_episode,
        parameters_update,
        {},
        load_epochs[0],
        purpose="prediction",
    )
    thresholds = self._thresholds().find_thresholds(
        episode_names,
        load_epochs,
        metric_name,
        metric_parameters=parameters["metrics"][metric_name],
    )
    # The label names are taken from the first episode only.
    behaviors = list(self._episode(first_episode).get_behaviors_dict().values())
    task = None
    return thresholds, behaviors, task
Compute optimal decision thresholds or load them if they have been computed before.
Parameters
episode_names : List
    a list of episode names
metric_name : {"f1", "segmental_f1", "semisegmental_f1", "f_beta", "segmental_f_beta"}
    the metric to optimize
parameters_update : dict, optional
    the parameter update dictionary
load_epochs : list, optional
    a list of epochs to load (by default last are loaded)
remove_saved_features : bool, default False
    if True, the dataset will be deleted after the computation
Returns
thresholds : list a list of float decision threshold values classes : list the label names corresponding to the values task : TaskDispatcher | None the task used in computation
def run_episode(
    self,
    episode_name: str,
    load_episode: str = None,
    parameters_update: Dict = None,
    task: TaskDispatcher = None,
    load_epoch: int = None,
    load_search: str = None,
    load_parameters: list = None,
    round_to_binary: list = None,
    load_strict: bool = True,
    n_seeds: int = 1,
    force: bool = False,
    suppress_name_check: bool = False,
    remove_saved_features: bool = False,
    mask_name: str = None,
    autostop_metric: str = None,
    autostop_interval: int = 50,
    autostop_threshold: float = 0.001,
    loading_bar: bool = False,
    trial: Tuple = None,
) -> TaskDispatcher:
    """Run an episode.

    The task parameters are read from the config files and then updated with the
    parameters_update dictionary. The model can be either initialized from scratch or loaded from one of the
    previous experiments. All parameters and results are saved in the meta files and can be accessed with the
    list_episodes() function. The train/test/validation split is saved and loaded from a file whenever the
    same split parameters are used. The pre-computed datasets are also saved and loaded whenever the same
    data parameters are used.

    You can use the autostop parameters to finish training when the tracked metric is no longer improving.
    Training will be stopped if the average value of `autostop_metric` over the last `autostop_interval`
    epochs is smaller than the average over the previous `autostop_interval` epochs + `autostop_threshold`.
    For example, if the current epoch is 120 and `autostop_interval` is 50, the averages over epochs 70-120
    and 20-70 will be compared.

    Parameters
    ----------
    episode_name : str
        the episode name
    load_episode : str, optional
        the (previously run) episode name to load the model from; if the episode has multiple runs,
        the new episode will have the same number of runs, each starting with one of the pre-trained models
    parameters_update : dict, optional
        the dictionary used to update the parameters from the config files
    task : TaskDispatcher, optional
        a pre-existing `TaskDispatcher` object (if provided, the method will update it instead of creating
        a new instance)
    load_epoch : int, optional
        the epoch to load (if load_episode is not None); if not provided, the last epoch is used
    load_search : str, optional
        the hyperparameter search result to load
    load_parameters : list, optional
        a list of string names of the parameters to load from load_search (if not provided, all parameters
        are loaded)
    round_to_binary : list, optional
        a list of string names of the loaded parameters that should be rounded to the nearest power of two
    load_strict : bool, default True
        if `False`, matching weights will be loaded from `load_episode` and differences in parameter name lists and
        weight shapes will be ignored; otherwise mismatches will prompt a `RuntimeError`
    n_seeds : int, default 1
        the number of runs to perform; if `n_seeds > 1`, the episodes will be named `episode_name#run_index`, e.g.
        `test_episode#0` and `test_episode#1`
    force : bool, default False
        if `True` and an episode with name `episode_name` already exists, it will be overwritten (use with caution!)
    suppress_name_check : bool, default False
        if `True`, episode names with a double colon are allowed (please don't use this option unless you understand
        why they are usually forbidden)
    remove_saved_features : bool, default False
        if `True`, the dataset will be deleted after training
    mask_name : str, optional
        the name of the real_lens to apply (forwarded to the task as `mask_name`)
    autostop_metric : str, optional
        the autostop metric (can be any one of the tracked metrics or `'loss'`)
    autostop_interval : int, default 50
        the number of epochs to average the autostop metric over
    autostop_threshold : float, default 0.001
        the autostop difference threshold
    loading_bar : bool, default False
        if `True`, a loading bar will be displayed
    trial : tuple, optional
        a tuple of (trial, metric) for hyperparameter search

    Returns
    -------
    TaskDispatcher
        the `TaskDispatcher` object (of the last run when several runs are performed)

    Raises
    ------
    ValueError
        if `n_seeds` is not a positive integer or if `mask_name` is combined with `n_seeds > 1`
    RuntimeError
        if the (single) episode fails for any reason other than an optuna pruning signal

    Notes
    -----
    When several runs are performed (multiple seeds or a multi-run `load_episode`), the work is delegated
    to `run_episodes`; `task`, `mask_name`, the autostop settings, `loading_bar` and `trial` are not
    forwarded in that case.

    """
    # Free as much (GPU) memory as possible before building a new task.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # `type(...) is not int` on purpose: booleans must be rejected as well.
    if type(n_seeds) is not int or n_seeds < 1:
        raise ValueError(
            f"The n_seeds parameter has to be an integer larger than 0; got {n_seeds}"
        )
    if n_seeds > 1 and mask_name is not None:
        raise ValueError("Cannot apply a real_lens with n_seeds > 1")
    self._check_episode_validity(
        episode_name, allow_doublecolon=suppress_name_check, force=force
    )
    load_runs = self._episodes().get_runs(load_episode)

    if len(load_runs) > 1:
        # One new run per pre-trained run, keeping the run suffix.
        n_runs = len(load_runs)
        task = self.run_episodes(
            episode_names=[
                f'{episode_name}#{run.split("#")[-1]}' for run in load_runs
            ],
            load_episodes=load_runs,
            parameters_updates=[parameters_update] * n_runs,
            load_epochs=[load_epoch] * n_runs,
            load_searches=[load_search] * n_runs,
            load_parameters=[load_parameters] * n_runs,
            # run_episodes forwards round_to_binary unchanged to every run,
            # so it must stay a flat list of parameter names
            round_to_binary=round_to_binary,
            load_strict=[load_strict] * n_runs,
            suppress_name_check=True,
            force=force,
            # the stores are shared between the runs: remove them once, at the end
            remove_saved_features=False,
        )
        if remove_saved_features:
            self._remove_stores(
                {
                    "general": task.general_parameters,
                    "data": task.data_parameters,
                    "features": task.feature_parameters,
                }
            )
        if n_seeds > 1:
            warnings.warn(
                f"The n_seeds parameter is disregarded since load_episode={load_episode} has multiple runs"
            )
        return task

    if n_seeds > 1:
        return self.run_episodes(
            episode_names=[f"{episode_name}#{i}" for i in range(n_seeds)],
            load_episodes=[load_episode] * n_seeds,
            parameters_updates=[parameters_update] * n_seeds,
            load_epochs=[load_epoch] * n_seeds,
            load_searches=[load_search] * n_seeds,
            load_parameters=[load_parameters] * n_seeds,
            round_to_binary=round_to_binary,
            load_strict=[load_strict] * n_seeds,
            suppress_name_check=True,
            force=force,
            remove_saved_features=remove_saved_features,
        )

    print(f"TRAINING {episode_name}")
    try:
        task, parameters = self._make_task_training(
            episode_name,
            load_episode,
            parameters_update,
            load_epoch,
            load_search,
            load_parameters,
            round_to_binary,
            continuing=False,
            task=task,
            mask_name=mask_name,
            load_strict=load_strict,
        )
        self._save_episode(
            episode_name,
            parameters,
            task.behaviors_dict(),
            norm_stats=task.get_normalization_stats(),
        )
        optuna_trial, optimized_metric = trial if trial is not None else (None, None)
        time_start = time.time()
        logs = task.train(
            autostop_metric=autostop_metric,
            autostop_interval=autostop_interval,
            autostop_threshold=autostop_threshold,
            loading_bar=loading_bar,
            trial=optuna_trial,
            optimized_metric=optimized_metric,
        )
        elapsed = int(time.time() - time_start)
        hours, remainder = divmod(elapsed, 3600)
        minutes, seconds = divmod(remainder, 60)
        training_time = f"{hours}:{minutes:02}:{seconds:02}"
        self._update_episode_results(episode_name, logs, training_time)
        if remove_saved_features:
            self._remove_stores(parameters)
        print("\n")
        return task
    except optuna.exceptions.TrialPruned:
        # pruning is a control-flow signal for the hyperparameter search
        raise
    except Exception as e:
        raise RuntimeError(f"Episode {episode_name} could not run") from e
Run an episode.
The task parameters are read from the config files and then updated with the parameters_update dictionary. The model can be either initialized from scratch or loaded from one of the previous experiments. All parameters and results are saved in the meta files and can be accessed with the list_episodes() function. The train/test/validation split is saved and loaded from a file whenever the same split parameters are used. The pre-computed datasets are also saved and loaded whenever the same data parameters are used.
You can use the autostop parameters to finish training when the tracked metric is no longer improving. Training will be
stopped if the average value of autostop_metric over the last autostop_interval epochs is smaller than
the average over the previous autostop_interval epochs + autostop_threshold. For example, if the
current epoch is 120 and autostop_interval is 50, the averages over epochs 70-120 and 20-70 will be compared.
Parameters
episode_name : str
    the episode name
load_episode : str, optional
    the (previously run) episode name to load the model from; if the episode has multiple runs,
    the new episode will have the same number of runs, each starting with one of the pre-trained models
parameters_update : dict, optional
    the dictionary used to update the parameters from the config files
task : TaskDispatcher, optional
    a pre-existing TaskDispatcher object (if provided, the method will update it instead of creating
    a new instance)
load_epoch : int, optional
    the epoch to load (if load_episode is not None); if not provided, the last epoch is used
load_search : str, optional
    the hyperparameter search result to load
load_parameters : list, optional
    a list of string names of the parameters to load from load_search (if not provided, all parameters
    are loaded)
round_to_binary : list, optional
    a list of string names of the loaded parameters that should be rounded to the nearest power of two
load_strict : bool, default True
    if False, matching weights will be loaded from load_episode and differences in parameter name lists and
    weight shapes will be ignored; otherwise mismatches will prompt a RuntimeError
n_seeds : int, default 1
    the number of runs to perform; if n_seeds > 1, the episodes will be named episode_name#run_index, e.g.
    test_episode#0 and test_episode#1
force : bool, default False
    if True and an episode with name episode_name already exists, it will be overwritten (use with caution!)
suppress_name_check : bool, default False
    if True, episode names with a double colon are allowed (please don't use this option unless you understand
    why they are usually forbidden)
remove_saved_features : bool, default False
    if True, the dataset will be deleted after training
mask_name : str, optional
    the name of the real_lens to apply
autostop_metric : str, optional
    the autostop metric (can be any one of the tracked metrics or 'loss')
autostop_interval : int, default 50
    the number of epochs to average the autostop metric over
autostop_threshold : float, default 0.001
    the autostop difference threshold
loading_bar : bool, default False
    if True, a loading bar will be displayed
trial : tuple, optional
    a tuple of (trial, metric) for hyperparameter search
Returns
TaskDispatcher
    the TaskDispatcher object
def run_episodes(
    self,
    episode_names: List,
    load_episodes: List = None,
    parameters_updates: List = None,
    load_epochs: List = None,
    load_searches: List = None,
    load_parameters: List = None,
    round_to_binary: List = None,
    load_strict: List = None,
    force: bool = False,
    suppress_name_check: bool = False,
    remove_saved_features: bool = False,
) -> TaskDispatcher:
    """Run multiple episodes in sequence (and re-use previously loaded information).

    For each episode, the task parameters are read from the config files and then updated with the
    corresponding dictionary from `parameters_updates`. The model can be either initialized from scratch or
    loaded from one of the previous experiments. All parameters and results are saved in the meta files and
    can be accessed with the list_episodes() function. The train/test/validation split is saved and loaded
    from a file whenever the same split parameters are used. The pre-computed datasets are also saved and
    loaded whenever the same data parameters are used.

    Parameters
    ----------
    episode_names : list
        a list of strings of episode names
    load_episodes : list, optional
        a list of strings of (previously run) episode names to load the model from; if the episode has multiple runs,
        the new episode will have the same number of runs, each starting with one of the pre-trained models
    parameters_updates : list, optional
        a list of dictionaries used to update the parameters from the config
    load_epochs : list, optional
        a list of integers used to specify the epoch to load (if load_episodes is not None)
    load_searches : list, optional
        a list of strings of hyperparameter search results to load
    load_parameters : list, optional
        a list of lists of string names of the parameters to load from the searches
    round_to_binary : list, optional
        a list of string names of the loaded parameters that should be rounded to the nearest power of two
        (the same list is used for every episode)
    load_strict : list, optional
        a list of boolean values specifying weight loading policy: if `False`, matching weights will be loaded from
        the corresponding episode and differences in parameter name lists and
        weight shapes will be ignored; otherwise mismatches will prompt a `RuntimeError` (by default `True` for
        every episode)
    force : bool, default False
        if `True` and an episode name is already taken, it will be overwritten (use with caution!)
    suppress_name_check : bool, default False
        if `True`, episode names with a double colon are allowed (please don't use this option unless you understand
        why they are usually forbidden)
    remove_saved_features : bool, default False
        if `True`, the dataset will be deleted after training

    Returns
    -------
    TaskDispatcher
        the task dispatcher object of the last episode (`None` if `episode_names` is empty)

    Raises
    ------
    ValueError
        if one of the per-episode lists does not have one entry per episode name

    """
    n_episodes = len(episode_names)
    per_episode = {
        "load_episodes": (load_episodes, None),
        "parameters_updates": (parameters_updates, None),
        "load_epochs": (load_epochs, None),
        "load_searches": (load_searches, None),
        "load_parameters": (load_parameters, None),
        "load_strict": (load_strict, True),
    }
    resolved = {}
    for arg_name, (values, default) in per_episode.items():
        if values is None:
            values = [default for _ in range(n_episodes)]
        elif len(values) != n_episodes:
            # zip() would silently drop the surplus episodes otherwise
            raise ValueError(
                f"The {arg_name} list has {len(values)} entries but there are "
                f"{n_episodes} episode names!"
            )
        resolved[arg_name] = values

    task = None
    for index, episode_name in enumerate(episode_names):
        # the task of the previous episode is passed on so that loaded data is re-used
        task = self.run_episode(
            episode_name=episode_name,
            load_episode=resolved["load_episodes"][index],
            parameters_update=resolved["parameters_updates"][index],
            task=task,
            load_epoch=resolved["load_epochs"][index],
            load_search=resolved["load_searches"][index],
            load_parameters=resolved["load_parameters"][index],
            round_to_binary=round_to_binary,
            load_strict=resolved["load_strict"][index],
            suppress_name_check=suppress_name_check,
            force=force,
            remove_saved_features=remove_saved_features,
        )
    return task
Run multiple episodes in sequence (and re-use previously loaded information).
For each episode, the task parameters are read from the config files and then updated with the corresponding dictionary from parameters_updates. The model can be either initialized from scratch or loaded from one of the previous experiments. All parameters and results are saved in the meta files and can be accessed with the list_episodes() function. The train/test/validation split is saved and loaded from a file whenever the same split parameters are used. The pre-computed datasets are also saved and loaded whenever the same data parameters are used.
Parameters
episode_names : list
    a list of strings of episode names
load_episodes : list, optional
    a list of strings of (previously run) episode names to load the model from; if the episode has multiple runs,
    the new episode will have the same number of runs, each starting with one of the pre-trained models
parameters_updates : list, optional
    a list of dictionaries used to update the parameters from the config
load_epochs : list, optional
    a list of integers used to specify the epoch to load (if load_episodes is not None)
load_searches : list, optional
    a list of strings of hyperparameter search results to load
load_parameters : list, optional
    a list of lists of string names of the parameters to load from the searches
round_to_binary : list, optional
    a list of string names of the loaded parameters that should be rounded to the nearest power of two
load_strict : list, optional
    a list of boolean values specifying weight loading policy: if False, matching weights will be loaded from
    the corresponding episode and differences in parameter name lists and
    weight shapes will be ignored; otherwise mismatches will prompt a RuntimeError (by default True for
    every episode)
force : bool, default False
    if True and an episode name is already taken, it will be overwritten (use with caution!)
suppress_name_check : bool, default False
    if True, episode names with a double colon are allowed (please don't use this option unless you understand
    why they are usually forbidden)
remove_saved_features : bool, default False
    if True, the dataset will be deleted after training
Returns
TaskDispatcher the task dispatcher object
def continue_episode(
    self,
    episode_name: str,
    num_epochs: int = None,
    task: TaskDispatcher = None,
    n_seeds: int = 1,
    remove_saved_features: bool = False,
    device: str = "cuda",
    num_cpus: int = None,
) -> TaskDispatcher:
    """Load an older episode and continue running from the latest checkpoint.

    All parameters as well as the model and optimizer state dictionaries are loaded from the episode.
    Runs that are already finished are skipped when `num_epochs` is not provided.

    Parameters
    ----------
    episode_name : str
        the name of the episode to continue
    num_epochs : int, optional
        the new number of epochs
    task : TaskDispatcher, optional
        a pre-existing task; if provided, the method will update the task instead of creating a new one
        (this might save time, mainly on dataset loading)
    n_seeds : int, default 1
        the number of runs to perform; if `n_seeds > 1`, the episodes will be named `episode_name#run_index`, e.g.
        `test_episode#0` and `test_episode#1`; runs that do not exist yet are started from scratch with the
        parameters of the first existing run
    remove_saved_features : bool, default False
        if `True`, pre-computed features will be deleted after the run
    device : str, default "cuda"
        the torch device to use
    num_cpus : int, optional
        the number of CPUs to use for data loading; if `None`, the number of available CPUs will be used

    Returns
    -------
    TaskDispatcher
        the task dispatcher (the `task` argument is returned unchanged if no run was continued)

    """
    runs = self._episodes().get_runs(episode_name)
    # Stays None when every run is skipped, so that the clean-up below
    # does not touch an unbound variable.
    parameters = None
    for run in runs:
        print(f"TRAINING {run}")
        if num_epochs is None and not self._episode(run).unfinished():
            continue
        parameters_update = {
            "training": {
                "num_epochs": num_epochs,
                "device": device,
            },
            "general": {"num_cpus": num_cpus},
        }
        task, parameters = self._make_task_training(
            run,
            load_episode=run,
            parameters_update=parameters_update,
            continuing=True,
            task=task,
        )
        time_start = time.time()
        logs = task.train()
        time_end = time.time()
        old_time = self._training_time(run)
        if np.isnan(old_time):
            # without a previously recorded time the total is unknown
            training_time = np.nan
        else:
            elapsed = int(time_end + old_time - time_start)
            hours, remainder = divmod(elapsed, 3600)
            minutes, seconds = divmod(remainder, 60)
            training_time = f"{hours}:{minutes:02}:{seconds:02}"
        self._save_episode(
            run,
            parameters,
            task.behaviors_dict(),
            suppress_validation=True,
            training_time=training_time,
            norm_stats=task.get_normalization_stats(),
        )
        self._update_episode_results(run, logs)
        print("\n")
    if len(runs) < n_seeds:
        # top up with fresh runs that copy the parameters of the first run
        for i in range(len(runs), n_seeds):
            self.run_episode(
                f"{episode_name}#{i}",
                parameters_update=self._episodes().load_parameters(runs[0]),
                task=task,
                suppress_name_check=True,
            )
    if remove_saved_features and parameters is not None:
        self._remove_stores(parameters)
    return task
Load an older episode and continue running from the latest checkpoint.
All parameters as well as the model and optimizer state dictionaries are loaded from the episode.
Parameters
episode_name : str
    the name of the episode to continue
num_epochs : int, optional
    the new number of epochs
task : TaskDispatcher, optional
    a pre-existing task; if provided, the method will update the task instead of creating a new one
    (this might save time, mainly on dataset loading)
n_seeds : int, default 1
    the number of runs to perform; if n_seeds > 1, the episodes will be named episode_name#run_index, e.g.
    test_episode#0 and test_episode#1
remove_saved_features : bool, default False
    if True, pre-computed features will be deleted after the run
device : str, default "cuda"
    the torch device to use
num_cpus : int, optional
    the number of CPUs to use for data loading; if None, the number of available CPUs will be used
Returns
TaskDispatcher the task dispatcher
def run_default_hyperparameter_search(
    self,
    search_name: str,
    model_name: str,
    metric: str = "f1",
    best_n: int = 3,
    direction: str = "maximize",
    load_episode: str = None,
    load_epoch: int = None,
    load_strict: bool = True,
    prune: bool = True,
    force: bool = False,
    remove_saved_features: bool = False,
    overlap: float = 0,
    num_epochs: int = 50,
    test_frac: float = None,
    n_trials: int = 150,
    batch_size: int = 32,
) -> Dict:
    """Run an optuna hyperparameter search with default parameters for a model.

    For the vast majority of cases, optimizing the default parameters should be enough.
    Check out `dlc2action.options.model_hyperparameters` for the lists of parameters.
    There are also options to set overlap, test fraction and number of epochs parameters for the search without
    modifying the project config files. However, if you want something more complex, look into
    `Project.run_hyperparameter_search`.

    The task parameters are read from the config files and updated with the search settings passed here.
    The model can be either initialized from scratch or loaded from a previously run episode.
    For each trial, the objective metric is averaged over a few best epochs.

    Parameters
    ----------
    search_name : str
        the name of the search to store it in the meta files and load in run_episode
    model_name : str
        the name of the model; has to be a key of `dlc2action.options.model_hyperparameters`
    metric : str, default "f1"
        the metric to maximize/minimize (see direction); if the metric has an `"average"` parameter and it is set to
        `"none"` in the config files, it will be reset to `"macro"` for the search
    best_n : int, default 3
        the number of epochs to average the metric; if 0, the last value is taken
    direction : {'maximize', 'minimize'}
        optimization direction
    load_episode : str, optional
        the name of the episode to load the model from
    load_epoch : int, optional
        the epoch to load the model from (if not provided, the last checkpoint is used)
    load_strict : bool, default True
        if `True`, the model will be loaded only if the parameters match exactly
    prune : bool, default True
        if `True`, experiments where the optimized metric is improving too slowly will be terminated
        (with optuna HyperBand pruner)
    force : bool, default False
        if `True`, existing searches with the same name will be overwritten
    remove_saved_features : bool, default False
        if `True`, pre-computed features will be deleted after each run (if the data parameters change)
    overlap : float, default 0
        the overlap to use for the search
    num_epochs : int, default 50
        the number of epochs to use for the search
    test_frac : float, optional
        the test fraction to use for the search
    n_trials : int, default 150
        the number of trials to run
    batch_size : int, default 32
        the batch size to use for the search

    Returns
    -------
    best_parameters : dict
        a dictionary of best parameters

    Raises
    ------
    ValueError
        if there is no default search space for `model_name`

    """
    if model_name not in options.model_hyperparameters:
        raise ValueError(
            f"There is no default search space for {model_name}! Please choose from {options.model_hyperparameters.keys()} or try project.run_hyperparameter_search()"
        )
    pars = {
        "general": {"overlap": overlap, "model_name": model_name},
        "training": {"num_epochs": num_epochs, "batch_size": batch_size},
    }
    if test_frac is not None:
        pars["training"]["test_frac"] = test_frac
    # Metric names that end with a number are skipped here.
    if not metric.split("_")[-1].isnumeric():
        project_pars = self._read_parameters()
        # A metric without an entry in the config (or without parameters)
        # simply has no "average" setting to override.
        metric_pars = project_pars.get("metrics", {}).get(metric) or {}
        if metric_pars.get("average") == "none":
            pars["metrics"] = {metric: {"average": "macro"}}
    return self.run_hyperparameter_search(
        search_name=search_name,
        search_space=options.model_hyperparameters[model_name],
        metric=metric,
        n_trials=n_trials,
        best_n=best_n,
        parameters_update=pars,
        direction=direction,
        load_episode=load_episode,
        load_epoch=load_epoch,
        load_strict=load_strict,
        prune=prune,
        force=force,
        remove_saved_features=remove_saved_features,
    )
Run an optuna hyperparameter search with default parameters for a model.
For the vast majority of cases, optimizing the default parameters should be enough.
Check out dlc2action.options.model_hyperparameters for the lists of parameters.
There are also options to set overlap, test fraction and number of epochs parameters for the search without
modifying the project config files. However, if you want something more complex, look into
Project.run_hyperparameter_search.
The task parameters are read from the config files and updated with the parameters_update dictionary. The model can be either initialized from scratch or loaded from a previously run episode. For each trial, the objective metric is averaged over a few best epochs.
Parameters
search_name : str
    the name of the search to store it in the meta files and load in run_episode
model_name : str
    the name of the model; has to be a key of dlc2action.options.model_hyperparameters
metric : str
    the metric to maximize/minimize (see direction); if the metric has an "average" parameter and it is set to
    "none" in the config files, it will be reset to "macro" for the search
best_n : int, default 3
    the number of epochs to average the metric; if 0, the last value is taken
direction : {'maximize', 'minimize'}
    optimization direction
load_episode : str, optional
    the name of the episode to load the model from
load_epoch : int, optional
    the epoch to load the model from (if not provided, the last checkpoint is used)
load_strict : bool, default True
    if True, the model will be loaded only if the parameters match exactly
prune : bool, default True
    if True, experiments where the optimized metric is improving too slowly will be terminated
    (with optuna HyperBand pruner)
force : bool, default False
    if True, existing searches with the same name will be overwritten
remove_saved_features : bool, default False
    if True, pre-computed features will be deleted after each run (if the data parameters change)
overlap : float, default 0
    the overlap to use for the search
num_epochs : int, default 50
    the number of epochs to use for the search
test_frac : float, optional
    the test fraction to use for the search
n_trials : int, default 150
    the number of trials to run
batch_size : int, default 32
    the batch size to use for the search
Returns
best_parameters : dict a dictionary of best parameters
def run_hyperparameter_search(
    self,
    search_name: str,
    search_space: Dict,
    metric: str = "f1",
    n_trials: int = 20,
    best_n: int = 1,
    parameters_update: Dict = None,
    direction: str = "maximize",
    load_episode: str = None,
    load_epoch: int = None,
    load_strict: bool = True,
    prune: bool = False,
    force: bool = False,
    remove_saved_features: bool = False,
    make_plots: bool = True,
) -> Dict:
    """Run an optuna hyperparameter search.

    For a simpler function that fits most use cases, check out `Project.run_default_hyperparameter_search()`.

    To use a default search space with this method, import `dlc2action.options.model_hyperparameters`. It is
    a dictionary where keys are model names and values are default search spaces.

    The task parameters are read from the config files and updated with the parameters_update dictionary.
    The model can be either initialized from scratch or loaded from a previously run episode.
    For each trial, the objective metric is averaged over a few best epochs.

    Parameters
    ----------
    search_name : str
        the name of the search to store it in the meta files and load in run_episode
    search_space : dict
        a dictionary representing the search space; of this general structure:
        {'group/param_name': ('float/int/float_log/int_log', start, end),
        'group/param_name': ('categorical', [choices])}, e.g.
        {'data/overlap': ('int', 5, 100), 'training/lr': ('float_log', 1e-4, 1e-2),
        'data/feature_extraction': ('categorical', ['kinematic', 'bones'])};
    metric : str, default f1
        the metric to maximize/minimize (see direction)
    n_trials : int, default 20
        the number of optimization trials to run
    best_n : int, default 1
        the number of epochs to average the metric; if 0, the last value is taken
    parameters_update : dict, optional
        the parameters update dictionary
    direction : {'maximize', 'minimize'}
        optimization direction
    load_episode : str, optional
        the name of the episode to load the model from
    load_epoch : int, optional
        the epoch to load the model from (if not provided, the last checkpoint is used)
    load_strict : bool, default True
        if `True`, the model will be loaded only if the parameters match exactly
    prune : bool, default False
        if `True`, experiments where the optimized metric is improving too slowly will be terminated
        (with optuna HyperBand pruner)
    force : bool, default False
        if `True`, existing searches with the same name will be overwritten
    remove_saved_features : bool, default False
        if `True`, pre-computed features will be deleted after each run (if the data parameters change)
    make_plots : bool, default True
        if `True`, optuna contour and parameter importance plots are written as html files
        to the search folder

    Returns
    -------
    dict
        a dictionary of best parameters

    Raises
    ------
    ValueError
        if the best metric value found by the study is 0 or infinity

    """
    self._check_search_validity(search_name, force=force)
    print(f"SEARCH {search_name}")
    # The search trains under a temporary "_<search_name>" episode; drop leftovers first.
    self.remove_episode(f"_{search_name}")
    if parameters_update is None:
        parameters_update = {}
    parameters_update = self._update(
        parameters_update, {"general": {"metric_functions": {metric}}}
    )
    parameters = self._make_parameters(
        f"_{search_name}",
        load_episode,
        parameters_update,
        parameters_update_second={"training": {"model_save_path": None}},
        load_epoch=load_epoch,
        load_strict=load_strict,
    )

    if prune:
        pruner = optuna.pruners.HyperbandPruner()
    else:
        pruner = optuna.pruners.NopPruner()
    study = optuna.create_study(direction=direction, pruner=pruner)
    runner = _Runner(
        search_space=search_space,
        load_episode=load_episode,
        load_epoch=load_epoch,
        metric=metric,
        average=best_n,
        task=None,
        remove_saved_features=remove_saved_features,
        project=self,
        search_name=search_name,
    )
    study.optimize(lambda trial: runner.run(trial, parameters), n_trials=n_trials)
    if make_plots:
        search_path = self._search_path(search_name)
        # makedirs(exist_ok=True): a re-run of the same search (or a missing parent
        # folder) must not abort after all trials have already been computed.
        os.makedirs(search_path, exist_ok=True)
        fig = optuna.visualization.plot_contour(study)
        plotly.offline.plot(
            fig, filename=os.path.join(search_path, f"{search_name}_contour.html")
        )
        fig = optuna.visualization.plot_param_importances(study)
        plotly.offline.plot(
            fig,
            filename=os.path.join(search_path, f"{search_name}_importances.html"),
        )
    best_params = study.best_params
    best_value = study.best_value
    if best_value == 0 or best_value == float("inf"):
        raise ValueError(
            f"Best metric value is {best_value}, check your partition method and make sure that all behaviors are present in the validation set!"
        )
    self._save_search(
        search_name,
        parameters,
        n_trials,
        best_params,
        best_value,
        metric,
        search_space,
    )
    self.remove_episode(f"_{search_name}")
    runner.clean()
    print(f"best parameters: {best_params}")
    print("\n")
    return best_params
Run an optuna hyperparameter search.
For a simpler function that fits most use cases, check out Project.run_default_hyperparameter_search().
To use a default search space with this method, import dlc2action.options.model_hyperparameters. It is
a dictionary where keys are model names and values are default search spaces.
The task parameters are read from the config files and updated with the parameters_update dictionary. The model can be either initialized from scratch or loaded from a previously run episode. For each trial, the objective metric is averaged over a few best epochs.
Parameters
search_name : str
    the name of the search to store it in the meta files and load in run_episode
search_space : dict
    a dictionary representing the search space; of this general structure:
    {'group/param_name': ('float/int/float_log/int_log', start, end),
    'group/param_name': ('categorical', [choices])}, e.g.
    {'data/overlap': ('int', 5, 100), 'training/lr': ('float_log', 1e-4, 1e-2),
    'data/feature_extraction': ('categorical', ['kinematic', 'bones'])};
metric : str, default f1
    the metric to maximize/minimize (see direction)
n_trials : int, default 20
    the number of optimization trials to run
best_n : int, default 1
    the number of epochs to average the metric; if 0, the last value is taken
parameters_update : dict, optional
    the parameters update dictionary
direction : {'maximize', 'minimize'}
    optimization direction
load_episode : str, optional
    the name of the episode to load the model from
load_epoch : int, optional
    the epoch to load the model from (if not provided, the last checkpoint is used)
load_strict : bool, default True
    if True, the model will be loaded only if the parameters match exactly
prune : bool, default False
    if True, experiments where the optimized metric is improving too slowly will be terminated
    (with optuna HyperBand pruner)
force : bool, default False
    if True, existing searches with the same name will be overwritten
remove_saved_features : bool, default False
    if True, pre-computed features will be deleted after each run (if the data parameters change)
Returns
dict a dictionary of best parameters
def run_prediction(
    self,
    prediction_name: str,
    episode_names: List,
    load_epochs: List = None,
    parameters_update: Dict = None,
    augment_n: int = 10,
    data_path: str = None,
    mode: str = "all",
    file_paths: Set = None,
    remove_saved_features: bool = False,
    frame_number_map_file: str = None,
    force: bool = False,
    embedding: bool = False,
) -> None:
    """Generate and store a prediction using models from previously run episodes.

    The raw output of `_make_prediction` is converted to a full-length prediction by the
    task dataset and then handed to `_save_prediction` together with the parameters,
    the inference time and the behavior dictionary.

    Parameters
    ----------
    prediction_name : str
        the name of the prediction
    episode_names : list
        a list of string episode names to load the models from
    load_epochs : list or int, optional
        a list of integer epoch indices to load the model from; if None, the last ones are used,
        if int the same epoch is used for all episodes
    parameters_update : dict, optional
        a dictionary of parameter updates
    augment_n : int, default 10
        the number of augmentations to average over
    data_path : str, optional
        the data path to run the prediction for
    mode : {'all', 'test', 'val', 'train'}
        the subset of the data to make the prediction for (forced to 'all' if data_path is not None)
    file_paths : set, optional
        a set of string file paths (data with all prefixes + feature files, in any order) to run the
        prediction for
    remove_saved_features : bool, default False
        if `True`, pre-computed features will be deleted
    frame_number_map_file : str, optional
        path to the frame number map file (accepted for compatibility; not read anywhere in this method)
    force : bool, default False
        if `True`, existing prediction with this name will be overwritten
    embedding : bool, default False
        if `True`, the prediction is made for the embedding task

    """
    self._check_prediction_validity(prediction_name, force=force)
    print(f"PREDICTION {prediction_name}")
    outcome = self._make_prediction(
        prediction_name,
        episode_names,
        load_epochs,
        parameters_update,
        data_path,
        file_paths,
        mode,
        augment_n,
        evaluate=False,
        embedding=embedding,
    )
    # `mode` is rebound on purpose: _make_prediction returns the subset it actually used.
    task, parameters, mode, raw_prediction, inference_time, behavior_dict = outcome
    dataset = task.dataset(mode)
    full_length = dataset.generate_full_length_prediction(raw_prediction)

    if remove_saved_features:
        self._remove_stores(parameters)

    self._save_prediction(
        prediction_name,
        full_length,
        parameters,
        task,
        mode,
        embedding,
        inference_time,
        behavior_dict,
    )
    print("\n")
Load models from previously run episodes to generate a prediction.
The probabilities predicted by the models are averaged.
Unless submission is True, the prediction results are saved as a pickled dictionary in the project_name/results/predictions folder
under the {episode_name}_{load_epoch}.pickle name. The file is a nested dictionary where the first-level
keys are the video ids, the second-level keys are the clip ids (like individual names) and the values
are the prediction arrays.
Parameters
prediction_name : str
    the name of the prediction
episode_names : list
    a list of string episode names to load the models from
load_epochs : list or int, optional
    a list of integer epoch indices to load the model from; if None, the last ones are used, if int the same epoch is used for all episodes
parameters_update : dict, optional
    a dictionary of parameter updates
augment_n : int, default 10
    the number of augmentations to average over
data_path : str, optional
    the data path to run the prediction for
mode : {'all', 'test', 'val', 'train'}
    the subset of the data to make the prediction for (forced to 'all' if data_path is not None)
file_paths : set, optional
    a set of string file paths (data with all prefixes + feature files, in any order) to run the prediction
    for
remove_saved_features : bool, default False
    if True, pre-computed features will be deleted
submission : bool, default False
    if True, a MABe-22 style submission file is generated (note: the run_prediction signature shown above does not accept this parameter)
frame_number_map_file : str, optional
    path to the frame number map file
force : bool, default False
    if True, existing prediction with this name will be overwritten
embedding : bool, default False
    if True, the prediction is made for the embedding task
def evaluate_prediction(
    self,
    prediction_name: str,
    parameters_update: Dict = None,
    data_path: str = None,
    annotation_path: str = None,
    file_paths: Set = None,
    mode: str = None,
    remove_saved_features: bool = False,
    annotation_type: str = "none",
    num_classes: int = None,  # Set when using data_path
) -> Tuple[float, dict]:
    """Load a saved prediction and evaluate it.

    Every file in the project's results/predictions/<prediction_name> folder is unpickled and
    stored under its video id (the part of the file name that precedes "_<prediction_name>").
    The parameters saved with the prediction are updated with `parameters_update`, the model
    parameters are dropped, and a task built from them evaluates the loaded predictions.

    Parameters
    ----------
    prediction_name : str
        the name of the prediction
    parameters_update : dict, optional
        a dictionary of parameter updates
    data_path : str, optional
        the data path to run the evaluation for
    annotation_path : str, optional
        the annotation path to run the evaluation for
    file_paths : set, optional
        a set of string file paths (data with all prefixes + feature files, in any order) to run the
        evaluation for
    mode : str, optional
        the subset of the data to make the evaluation for (forced to 'all' if data_path is not None)
    remove_saved_features : bool, default False
        if `True`, pre-computed features will be deleted
    annotation_type : str, default 'none'
        the type of annotation to use for evaluation
    num_classes : int, optional
        the number of classes in the dataset, must be set with data_path

    Returns
    -------
    results : dict
        a dictionary of average values of metric functions
        (NOTE(review): the return annotation says Tuple[float, dict] while the value returned is the
        output of `Project._reformat_results` - confirm which one is right)

    Raises
    ------
    ValueError
        if `data_path` is provided without `num_classes`

    """
    prediction_path = os.path.join(
        self.project_path, "results", "predictions", f"{prediction_name}"
    )
    prediction_dict = {}
    for file_name in os.listdir(prediction_path):
        # NOTE: pickle can execute arbitrary code on load; only files written by this
        # project should live in the predictions folder.
        with open(os.path.join(prediction_path, file_name), "rb") as f:
            prediction = pickle.load(f)
        video_id = file_name.split("_" + prediction_name)[0]
        prediction_dict[video_id] = prediction
    if parameters_update is None:
        parameters_update = {}
    parameters_update = self._update(
        self._predictions().load_parameters(prediction_name), parameters_update
    )
    # The model is not needed to score stored predictions; tolerate a missing key.
    parameters_update.pop("model", None)
    if data_path is not None:
        # Raised explicitly (not asserted) so the check survives `python -O`.
        if num_classes is None:
            raise ValueError("num_classes must be provided if data_path is provided")
        parameters_update["general"]["num_classes"] = num_classes + int(
            parameters_update["general"]["exclusive"]
        )
    task, parameters, mode = self._make_task_prediction(
        "_",
        load_episode=None,
        parameters_update=parameters_update,
        data_path=data_path,
        annotation_path=annotation_path,
        file_paths=file_paths,
        mode=mode,
        annotation_type=annotation_type,
    )
    results = task.evaluate_prediction(prediction_dict, data=mode)
    if remove_saved_features:
        self._remove_stores(parameters)
    results = Project._reformat_results(
        results[1],
        task.behaviors_dict(),
        exclusive=task.general_parameters["exclusive"],
    )
    return results
Load a saved prediction and evaluate it
inputs:
    prediction_name (str): the name of the prediction
    parameters_update (dict): a dictionary of parameter updates
    data_path (str): the data path to run the prediction for
    annotation_path (str): the annotation path to run the prediction for
    file_paths (set): a set of string file paths (data with all prefixes + feature files, in any order) to run the prediction for
    mode (str): the subset of the data to make the prediction for (forced to 'all' if data_path is not None)
    remove_saved_features (bool): if True, pre-computed features will be deleted
    annotation_type (str): the type of annotation to use for evaluation
    num_classes (int): the number of classes in the dataset, must be set with data_path
outputs:
    results (dict): a dictionary of average values of metric functions
def evaluate(
    self,
    episode_names: List,
    load_epochs: List = None,
    augment_n: int = 0,
    data_path: str = None,
    file_paths: Set = None,
    mode: str = None,
    parameters_update: Dict = None,
    multiple_episode_policy: str = "average",
    remove_saved_features: bool = False,
    skip_updating_meta: bool = True,
    annotation_type: str = "none",
) -> Dict:
    """Load one or several models from previously run episodes to make an evaluation.

    By default it will run on the test (or validation, if there is no test) subset of the project dataset.

    Parameters
    ----------
    episode_names : list
        a list of string episode names to load the models from
    load_epochs : list, optional
        a list of integer epoch indices to load the model from; if None, the last ones are used
    augment_n : int, default 0
        the number of augmentations to average over
    data_path : str, optional
        the data path to run the prediction for
    file_paths : set, optional
        a set of files to run the prediction for
    mode : {'test', 'val', 'train', 'all'}
        the subset of the data to make the prediction for (forced to 'all' if data_path is not None;
        by default 'test' if test subset is not empty and 'val' otherwise)
    parameters_update : dict, optional
        a dictionary with parameter updates (cannot change model parameters)
    multiple_episode_policy : {'average', 'statistics'}
        the policy to use when multiple episodes are provided
    remove_saved_features : bool, default False
        if `True`, the dataset will be deleted
    skip_updating_meta : bool, default True
        if `True`, the meta file will not be updated with the computed metrics
    annotation_type : str, default 'none'
        the annotation type forwarded to the prediction step

    Returns
    -------
    metric : dict
        with the 'average' policy, the reformatted dictionary of metric values; with the
        'statistics' policy, a dictionary mapping each metric name to a dictionary with the
        'mean', 'std' and 'all' (per-run list) entries

    Raises
    ------
    ValueError
        if `multiple_episode_policy` is not one of 'average' and 'statistics'

    """
    names = []
    for episode_name in episode_names:
        names += self._episodes().get_runs(episode_name)
    if len(set(episode_names)) == 1:
        print(f"EVALUATION {episode_names[0]}")
    else:
        print(f"EVALUATION {episode_names}")
    several_runs = len(names) > 1
    # Bound up front: with the 'statistics' policy and no runs the loop body never
    # executes, and the final report must not fail on an unbound name.
    inference_time = None
    if multiple_episode_policy == "average":
        task, parameters, mode, prediction, inference_time, behavior_dict = (
            self._make_prediction(
                "_",
                episode_names,
                load_epochs,
                parameters_update,
                mode=mode,
                data_path=data_path,
                file_paths=file_paths,
                augment_n=augment_n,
                evaluate=several_runs,
                annotation_type=annotation_type,
            )
        )
        print("EVALUATE PREDICTION:")
        # Position of every integer key 0..n-1 in the dictionary's own key order.
        behavior_keys = list(behavior_dict.keys())
        indices = [behavior_keys.index(i) for i in range(len(behavior_keys))]
        _, results = task.evaluate_prediction(
            prediction, data=mode, indices=indices
        )
        if len(names) == 1 and mode == "val" and not skip_updating_meta:
            self._update_episode_metrics(names[0], results)
        results = Project._reformat_results(
            results,
            behavior_dict,
            exclusive=task.general_parameters["exclusive"],
        )

    elif multiple_episode_policy == "statistics":
        values = defaultdict(list)
        task = None
        for run_name in names:
            (
                task,
                parameters,
                mode,
                prediction,
                inference_time,
                behavior_dict,
            ) = self._make_prediction(
                "_",
                [run_name],
                load_epochs,
                parameters_update,
                mode=mode,
                data_path=data_path,
                file_paths=file_paths,
                augment_n=augment_n,
                evaluate=several_runs,
                annotation_type=annotation_type,
                task=task,
            )
            _, metrics = task.evaluate_prediction(
                prediction, data=mode, indices=list(behavior_dict.keys())
            )
            # A separate loop variable keeps `run_name` intact for the meta update below.
            for metric_name, value in metrics.items():
                values[metric_name].append(value)
            if mode == "val" and not skip_updating_meta:
                self._update_episode_metrics(run_name, metrics)
        results = defaultdict(dict)
        mean_string = ""
        std_string = ""
        for key, value_list in values.items():
            mean_value = np.mean(value_list)
            std_value = np.std(value_list)
            results[key]["mean"] = mean_value
            results[key]["std"] = std_value
            results[key]["all"] = value_list
            mean_string += f"{key} {mean_value:.3f}, "
            std_string += f"{key} {std_value:.3f}, "
        print("MEAN:")
        print(mean_string)
        print("STD:")
        print(std_string)
    else:
        raise ValueError(
            f"The {multiple_episode_policy} multiple episode policy is not recognized; please choose "
            f"from ['average', 'statistics']"
        )
    if len(names) > 0 and remove_saved_features:
        self._remove_stores(parameters)
    print(f"Inference time: {inference_time}")
    print("\n")
    return results
Load one or several models from previously run episodes to make an evaluation.
By default it will run on the test (or validation, if there is no test) subset of the project dataset.
Parameters
episode_names : list
    a list of string episode names to load the models from
load_epochs : list, optional
    a list of integer epoch indices to load the model from; if None, the last ones are used
augment_n : int, default 0
    the number of augmentations to average over
data_path : str, optional
    the data path to run the prediction for
file_paths : set, optional
    a set of files to run the prediction for
mode : {'test', 'val', 'train', 'all'}
    the subset of the data to make the prediction for (forced to 'all' if data_path is not None;
    by default 'test' if test subset is not empty and 'val' otherwise)
parameters_update : dict, optional
    a dictionary with parameter updates (cannot change model parameters)
multiple_episode_policy : {'average', 'statistics'}
    the policy to use when multiple episodes are provided
remove_saved_features : bool, default False
    if True, the dataset will be deleted
skip_updating_meta : bool, default True
    if True, the meta file will not be updated with the computed metrics
Returns
metric : dict a dictionary of average values of metric functions
1563 def run_suggestion( 1564 self, 1565 suggestions_name: str, 1566 error_episode: str = None, 1567 error_load_epoch: int = None, 1568 error_class: str = None, 1569 suggestions_prediction: str = None, 1570 suggestion_episodes: List = [None], 1571 suggestion_load_epoch: int = None, 1572 suggestion_classes: List = None, 1573 error_threshold: float = 0.5, 1574 error_threshold_diff: float = 0.1, 1575 error_hysteresis: bool = False, 1576 suggestion_threshold: Union[float, List] = 0.5, 1577 suggestion_threshold_diff: Union[float, List] = 0.1, 1578 suggestion_hysteresis: Union[bool, List] = True, 1579 min_frames_suggestion: int = 10, 1580 min_frames_al: int = 30, 1581 visibility_min_score: float = 0, 1582 visibility_min_frac: float = 0.7, 1583 augment_n: int = 0, 1584 exclude_classes: List = None, 1585 exclude_threshold: Union[float, List] = 0.6, 1586 exclude_threshold_diff: Union[float, List] = 0.1, 1587 exclude_hysteresis: Union[bool, List] = False, 1588 include_classes: List = None, 1589 include_threshold: Union[float, List] = 0.4, 1590 include_threshold_diff: Union[float, List] = 0.1, 1591 include_hysteresis: Union[bool, List] = False, 1592 data_path: str = None, 1593 file_paths: Set = None, 1594 parameters_update: Dict = None, 1595 mode: str = "all", 1596 force: bool = False, 1597 remove_saved_features: bool = False, 1598 cut_annotated: bool = False, 1599 background_threshold: float = None, 1600 ) -> None: 1601 """Create active learning and suggestion files. 1602 1603 Generate predictions with the error and suggestion model and use them to create 1604 suggestion files for the labeling interface. Those files will render as suggested labels 1605 at intervals with high pose estimation quality. Quality here is defined by probability of error 1606 (predicted by the error model) and visibility parameters. 
1607 1608 If `error_episode` or `exclude_classes` is not `None`, 1609 an active learning file will be created as well (with frames with high predicted probability of classes 1610 from `exclude_classes` and/or errors excluded from the active learning intervals). 1611 1612 In all three steps (predicting errors, suggesting labels and excluding them from active learning intervals) 1613 you can apply one of three methods. 1614 1615 - **Simple threshold** 1616 1617 Set the `hysteresis` parameter (e.g. `error_hysteresis`) to `False` and the `threshold` 1618 parameter to $\alpha$. 1619 In this case if the probability of a label is predicted to be higher than $\alpha$ the frame will 1620 be considered labeled. 1621 1622 - **Hysteresis threshold** 1623 1624 Set the `hysteresis` parameter (e.g. `error_hysteresis`) to `True`, the `threshold` 1625 parameter to $\alpha$ and the `threshold_diff` parameter to $\beta$. 1626 Now intervals will be marked with a label if the probability of that label for all frames is higher 1627 than $\alpha - \beta$ and at least for one frame in that interval it is higher than $\alpha$. 1628 1629 - **Max hysteresis threshold** 1630 1631 Set the `hysteresis` parameter (e.g. `error_hysteresis`) to `True`, the `threshold` 1632 parameter to $\alpha$ and the `threshold_diff` parameter to `None`. 1633 With this combination intervals are marked with a label if that label is more likely than any other 1634 for all frames in this interval and at for at least one of those frames its probability is higher than 1635 $\alpha$. 
1636 1637 Parameters 1638 ---------- 1639 suggestions_name : str 1640 the name of the suggestions 1641 error_episode : str, optional 1642 the name of the episode where the error model should be loaded from 1643 error_load_epoch : int, optional 1644 the epoch the error model should be loaded from 1645 error_class : str, optional 1646 the name of the error class (in `error_episode`) 1647 suggestions_prediction : str, optional 1648 the name of the predictions that should be used for the suggestion model 1649 suggestion_episodes : list, optional 1650 the names of the episodes where the suggestion models should be loaded from 1651 suggestion_load_epoch : int, optional 1652 the epoch the suggestion model should be loaded from 1653 suggestion_classes : list, optional 1654 a list of string names of the classes that should be suggested (in `suggestion_episode`) 1655 error_threshold : float, default 0.5 1656 the hard threshold for error prediction 1657 error_threshold_diff : float, default 0.1 1658 the difference between soft and hard thresholds for error prediction (in case hysteresis is used) 1659 error_hysteresis : bool, default False 1660 if True, hysteresis is used for error prediction 1661 suggestion_threshold : float | list, default 0.5 1662 the hard threshold for class prediction (use a list to set different rules for different classes) 1663 suggestion_threshold_diff : float | list, default 0.1 1664 the difference between soft and hard thresholds for class prediction (in case hysteresis is used; 1665 use a list to set different rules for different classes) 1666 suggestion_hysteresis : bool | list, default True 1667 if True, hysteresis is used for class prediction (use a list to set different rules for different classes) 1668 min_frames_suggestion : int, default 10 1669 only actions longer than this number of frames will be suggested 1670 min_frames_al : int, default 30 1671 only active learning intervals longer than this number of frames will be suggested 1672 
visibility_min_score : float, default 0 1673 the minimum visibility score for visibility filtering 1674 visibility_min_frac : float, default 0.7 1675 the minimum fraction of visible frames for visibility filtering 1676 augment_n : int, default 10 1677 the number of augmentations to average the predictions over 1678 exclude_classes : list, optional 1679 a list of string names of classes that should be excluded from the active learning intervals 1680 exclude_threshold : float | list, default 0.6 1681 the hard threshold for excluded class prediction (use a list to set different rules for different classes) 1682 exclude_threshold_diff : float | list, default 0.1 1683 the difference between soft and hard thresholds for excluded class prediction (in case hysteresis is used) 1684 exclude_hysteresis : bool | list, default False 1685 if True, hysteresis is used for excluded class prediction (use a list to set different rules for different classes) 1686 include_classes : list, optional 1687 a list of string names of classes that should be included into the active learning intervals 1688 include_threshold : float | list, default 0.6 1689 the hard threshold for included class prediction (use a list to set different rules for different classes) 1690 include_threshold_diff : float | list, default 0.1 1691 the difference between soft and hard thresholds for included class prediction (in case hysteresis is used) 1692 include_hysteresis : bool | list, default False 1693 if True, hysteresis is used for included class prediction (use a list to set different rules for different classes) 1694 data_path : str, optional 1695 the data path to run the prediction for 1696 file_paths : set, optional 1697 a set of string file paths (data with all prefixes + feature files, in any order) to run the prediction 1698 for 1699 parameters_update : dict, optional 1700 the parameters update dictionary 1701 mode : {'all', 'test', 'val', 'train'} 1702 the subset of the data to make the prediction for 
(forced to 'all' if data_path is not None) 1703 force : bool, default False 1704 if `True` and an episode with name `episode_name` already exists, it will be overwritten (use with caution!) 1705 remove_saved_features : bool, default False 1706 if `True`, the dataset will be deleted. 1707 cut_annotated : bool, default False 1708 if `True`, annotated frames will be cut from the suggestions 1709 background_threshold : float, default 0.5 1710 the threshold for background prediction 1711 1712 """ 1713 self._check_suggestions_validity(suggestions_name, force=force) 1714 if any([x is None for x in suggestion_episodes]): 1715 suggestion_episodes = None 1716 if error_episode is None and ( 1717 suggestion_episodes is None and suggestions_prediction is None 1718 ): 1719 raise ValueError( 1720 "Both error_episode and suggestion_episode parameters cannot be None at the same time" 1721 ) 1722 print(f"SUGGESTION {suggestions_name}") 1723 task = None 1724 if suggestion_classes is None: 1725 suggestion_classes = [] 1726 if exclude_classes is None: 1727 exclude_classes = [] 1728 if include_classes is None: 1729 include_classes = [] 1730 if isinstance(suggestion_threshold, list): 1731 if len(suggestion_threshold) != len(suggestion_classes): 1732 raise ValueError( 1733 "The suggestion_threshold parameter has to be either a float value or a list of " 1734 f"float values of the same length as suggestion_classes (got a list of length " 1735 f"{len(suggestion_threshold)} for {len(suggestion_classes)} classes)" 1736 ) 1737 else: 1738 suggestion_threshold = [suggestion_threshold for _ in suggestion_classes] 1739 if isinstance(suggestion_threshold_diff, list): 1740 if len(suggestion_threshold_diff) != len(suggestion_classes): 1741 raise ValueError( 1742 "The suggestion_threshold_diff parameter has to be either a float value or a list of " 1743 f"float values of the same length as suggestion_classes (got a list of length " 1744 f"{len(suggestion_threshold)} for {len(suggestion_classes)} 
classes)" 1745 ) 1746 else: 1747 suggestion_threshold_diff = [ 1748 suggestion_threshold_diff for _ in suggestion_classes 1749 ] 1750 if isinstance(suggestion_hysteresis, list): 1751 if len(suggestion_hysteresis) != len(suggestion_classes): 1752 raise ValueError( 1753 "The suggestion_threshold_diff parameter has to be either a float value or a list of " 1754 f"float values of the same length as suggestion_classes (got a list of length " 1755 f"{len(suggestion_hysteresis)} for {len(suggestion_classes)} classes)" 1756 ) 1757 else: 1758 suggestion_hysteresis = [suggestion_hysteresis for _ in suggestion_classes] 1759 if isinstance(exclude_threshold, list): 1760 if len(exclude_threshold) != len(exclude_classes): 1761 raise ValueError( 1762 "The exclude_threshold parameter has to be either a float value or a list of " 1763 f"float values of the same length as exclude_classes (got a list of length " 1764 f"{len(exclude_threshold)} for {len(exclude_classes)} classes)" 1765 ) 1766 else: 1767 exclude_threshold = [exclude_threshold for _ in exclude_classes] 1768 if isinstance(exclude_threshold_diff, list): 1769 if len(exclude_threshold_diff) != len(exclude_classes): 1770 raise ValueError( 1771 "The exclude_threshold_diff parameter has to be either a float value or a list of " 1772 f"float values of the same length as exclude_classes (got a list of length " 1773 f"{len(exclude_threshold_diff)} for {len(exclude_classes)} classes)" 1774 ) 1775 else: 1776 exclude_threshold_diff = [exclude_threshold_diff for _ in exclude_classes] 1777 if isinstance(exclude_hysteresis, list): 1778 if len(exclude_hysteresis) != len(exclude_classes): 1779 raise ValueError( 1780 "The suggestion_threshold_diff parameter has to be either a float value or a list of " 1781 f"float values of the same length as suggestion_classes (got a list of length " 1782 f"{len(exclude_hysteresis)} for {len(exclude_classes)} classes)" 1783 ) 1784 else: 1785 exclude_hysteresis = [exclude_hysteresis for _ in 
exclude_classes] 1786 if isinstance(include_threshold, list): 1787 if len(include_threshold) != len(include_classes): 1788 raise ValueError( 1789 "The exclude_threshold parameter has to be either a float value or a list of " 1790 f"float values of the same length as exclude_classes (got a list of length " 1791 f"{len(include_threshold)} for {len(include_classes)} classes)" 1792 ) 1793 else: 1794 include_threshold = [include_threshold for _ in include_classes] 1795 if isinstance(include_threshold_diff, list): 1796 if len(include_threshold_diff) != len(include_classes): 1797 raise ValueError( 1798 "The exclude_threshold_diff parameter has to be either a float value or a list of " 1799 f"float values of the same length as exclude_classes (got a list of length " 1800 f"{len(include_threshold_diff)} for {len(include_classes)} classes)" 1801 ) 1802 else: 1803 include_threshold_diff = [include_threshold_diff for _ in include_classes] 1804 if isinstance(include_hysteresis, list): 1805 if len(include_hysteresis) != len(include_classes): 1806 raise ValueError( 1807 "The suggestion_threshold_diff parameter has to be either a float value or a list of " 1808 f"float values of the same length as suggestion_classes (got a list of length " 1809 f"{len(include_hysteresis)} for {len(include_classes)} classes)" 1810 ) 1811 else: 1812 include_hysteresis = [include_hysteresis for _ in include_classes] 1813 if (suggestion_episodes is None and suggestions_prediction is None) and len( 1814 exclude_classes 1815 ) > 0: 1816 raise ValueError( 1817 "In order to exclude classes from the active learning intervals you need to set the " 1818 "suggestion_episode parameter" 1819 ) 1820 1821 task = None 1822 if error_episode is not None: 1823 task, parameters, mode = self._make_task_prediction( 1824 prediction_name=suggestions_name, 1825 load_episode=error_episode, 1826 parameters_update=parameters_update, 1827 load_epoch=error_load_epoch, 1828 data_path=data_path, 1829 mode=mode, 1830 
file_paths=file_paths, 1831 task=task, 1832 ) 1833 predicted_error = task.predict( 1834 data=mode, 1835 raw_output=True, 1836 apply_primary_function=True, 1837 augment_n=augment_n, 1838 ) 1839 else: 1840 predicted_error = None 1841 1842 if suggestion_episodes is not None: 1843 ( 1844 task, 1845 parameters, 1846 mode, 1847 predicted_classes, 1848 inference_time, 1849 behavior_dict, 1850 ) = self._make_prediction( 1851 prediction_name=suggestions_name, 1852 episode_names=suggestion_episodes, 1853 load_epochs=suggestion_load_epoch, 1854 parameters_update=parameters_update, 1855 data_path=data_path, 1856 file_paths=file_paths, 1857 mode=mode, 1858 task=task, 1859 ) 1860 elif suggestions_prediction is not None: 1861 with open( 1862 os.path.join( 1863 self.project_path, 1864 "results", 1865 "predictions", 1866 f"{suggestions_prediction}.pickle", 1867 ), 1868 "rb", 1869 ) as f: 1870 predicted_classes = pickle.load(f) 1871 if parameters_update is None: 1872 parameters_update = {} 1873 parameters_update = self._update( 1874 self._predictions().load_parameters(suggestions_prediction), 1875 parameters_update, 1876 ) 1877 parameters_update.pop("model") 1878 if suggestion_episodes is None: 1879 suggestion_episodes = [ 1880 os.path.basename( 1881 os.path.dirname( 1882 parameters_update["training"]["checkpoint_path"] 1883 ) 1884 ) 1885 ] 1886 task, parameters, mode = self._make_task_prediction( 1887 "_", 1888 load_episode=None, 1889 parameters_update=parameters_update, 1890 data_path=data_path, 1891 file_paths=file_paths, 1892 mode=mode, 1893 ) 1894 else: 1895 predicted_classes = None 1896 1897 if len(suggestion_classes) > 0 and predicted_classes is not None: 1898 suggestions = self._make_suggestions( 1899 task, 1900 predicted_error, 1901 predicted_classes, 1902 suggestion_threshold, 1903 suggestion_threshold_diff, 1904 suggestion_hysteresis, 1905 suggestion_episodes, 1906 suggestion_classes, 1907 error_threshold, 1908 min_frames_suggestion, 1909 min_frames_al, 1910 
visibility_min_score, 1911 visibility_min_frac, 1912 cut_annotated=cut_annotated, 1913 ) 1914 videos = list(suggestions.keys()) 1915 for v_id in videos: 1916 times_dict = defaultdict(lambda: defaultdict(lambda: [])) 1917 clips = set() 1918 for c in suggestions[v_id]: 1919 for start, end, ind in suggestions[v_id][c]: 1920 times_dict[ind][c].append([start, end, 2]) 1921 clips.add(ind) 1922 clips = list(clips) 1923 times_dict = dict(times_dict) 1924 times = [ 1925 [times_dict[ind][c] for c in suggestion_classes] for ind in clips 1926 ] 1927 save_path = self._suggestion_path(v_id, suggestions_name) 1928 Path(os.path.dirname(save_path)).mkdir(parents=True, exist_ok=True) 1929 with open(save_path, "wb") as f: 1930 pickle.dump((None, suggestion_classes, clips, times), f) 1931 1932 if ( 1933 error_episode is not None 1934 or len(exclude_classes) > 0 1935 or len(include_classes) > 0 1936 ): 1937 al_points = self._make_al_points( 1938 task, 1939 predicted_error, 1940 predicted_classes, 1941 exclude_classes, 1942 exclude_threshold, 1943 exclude_threshold_diff, 1944 exclude_hysteresis, 1945 include_classes, 1946 include_threshold, 1947 include_threshold_diff, 1948 include_hysteresis, 1949 error_episode, 1950 error_class, 1951 suggestion_episodes, 1952 error_threshold, 1953 error_threshold_diff, 1954 error_hysteresis, 1955 min_frames_al, 1956 visibility_min_score, 1957 visibility_min_frac, 1958 ) 1959 else: 1960 al_points = self._make_al_points_from_suggestions( 1961 suggestions_name, 1962 task, 1963 predicted_classes, 1964 background_threshold, 1965 visibility_min_score, 1966 visibility_min_frac, 1967 num_behaviors=len(task.behaviors_dict()), 1968 ) 1969 save_path = self._al_points_path(suggestions_name) 1970 Path(os.path.dirname(save_path)).mkdir(parents=True, exist_ok=True) 1971 with open(save_path, "wb") as f: 1972 pickle.dump(al_points, f) 1973 1974 meta_parameters = { 1975 "error_episode": error_episode, 1976 "error_load_epoch": error_load_epoch, 1977 "error_class": 
error_class, 1978 "suggestion_episode": suggestion_episodes, 1979 "suggestion_load_epoch": suggestion_load_epoch, 1980 "suggestion_classes": suggestion_classes, 1981 "error_threshold": error_threshold, 1982 "error_threshold_diff": error_threshold_diff, 1983 "error_hysteresis": error_hysteresis, 1984 "suggestion_threshold": suggestion_threshold, 1985 "suggestion_threshold_diff": suggestion_threshold_diff, 1986 "suggestion_hysteresis": suggestion_hysteresis, 1987 "min_frames_suggestion": min_frames_suggestion, 1988 "min_frames_al": min_frames_al, 1989 "visibility_min_score": visibility_min_score, 1990 "visibility_min_frac": visibility_min_frac, 1991 "augment_n": augment_n, 1992 "exclude_classes": exclude_classes, 1993 "exclude_threshold": exclude_threshold, 1994 "exclude_threshold_diff": exclude_threshold_diff, 1995 "exclude_hysteresis": exclude_hysteresis, 1996 } 1997 self._save_suggestions(suggestions_name, {}, meta_parameters) 1998 if data_path is not None or file_paths is not None or remove_saved_features: 1999 self._remove_stores(parameters) 2000 print(f"\n")
Create active learning and suggestion files.
Generate predictions with the error and suggestion model and use them to create suggestion files for the labeling interface. Those files will render as suggested labels at intervals with high pose estimation quality. Quality here is defined by probability of error (predicted by the error model) and visibility parameters.
If error_episode or exclude_classes is not None,
an active learning file will be created as well (with frames with high predicted probability of classes
from exclude_classes and/or errors excluded from the active learning intervals).
In all three steps (predicting errors, suggesting labels and excluding them from active learning intervals) you can apply one of three methods.
- Simple threshold - Set the `hysteresis` parameter (e.g. `error_hysteresis`) to `False` and the `threshold` parameter to $\alpha$. In this case, if the probability of a label is predicted to be higher than $\alpha$, the frame will be considered labeled.
- Hysteresis threshold - Set the `hysteresis` parameter (e.g. `error_hysteresis`) to `True`, the `threshold` parameter to $\alpha$ and the `threshold_diff` parameter to $\beta$. Now intervals will be marked with a label if the probability of that label for all frames is higher than $\alpha - \beta$ and for at least one frame in that interval it is higher than $\alpha$.
- Max hysteresis threshold - Set the `hysteresis` parameter (e.g. `error_hysteresis`) to `True`, the `threshold` parameter to $\alpha$ and the `threshold_diff` parameter to `None`. With this combination intervals are marked with a label if that label is more likely than any other for all frames in this interval and, for at least one of those frames, its probability is higher than $\alpha$.
Parameters
suggestions_name : str
    the name of the suggestions
error_episode : str, optional
    the name of the episode where the error model should be loaded from
error_load_epoch : int, optional
    the epoch the error model should be loaded from
error_class : str, optional
    the name of the error class (in error_episode)
suggestions_prediction : str, optional
    the name of the predictions that should be used for the suggestion model
suggestion_episodes : list, optional
    the names of the episodes where the suggestion models should be loaded from
suggestion_load_epoch : int, optional
    the epoch the suggestion model should be loaded from
suggestion_classes : list, optional
    a list of string names of the classes that should be suggested (in suggestion_episode)
error_threshold : float, default 0.5
    the hard threshold for error prediction
error_threshold_diff : float, default 0.1
    the difference between soft and hard thresholds for error prediction (in case hysteresis is used)
error_hysteresis : bool, default False
    if True, hysteresis is used for error prediction
suggestion_threshold : float | list, default 0.5
    the hard threshold for class prediction (use a list to set different rules for different classes)
suggestion_threshold_diff : float | list, default 0.1
    the difference between soft and hard thresholds for class prediction (in case hysteresis is used;
    use a list to set different rules for different classes)
suggestion_hysteresis : bool | list, default True
    if True, hysteresis is used for class prediction (use a list to set different rules for different classes)
min_frames_suggestion : int, default 10
    only actions longer than this number of frames will be suggested
min_frames_al : int, default 30
    only active learning intervals longer than this number of frames will be suggested
visibility_min_score : float, default 0
    the minimum visibility score for visibility filtering
visibility_min_frac : float, default 0.7
    the minimum fraction of visible frames for visibility filtering
augment_n : int, default 10
    the number of augmentations to average the predictions over
exclude_classes : list, optional
    a list of string names of classes that should be excluded from the active learning intervals
exclude_threshold : float | list, default 0.6
    the hard threshold for excluded class prediction (use a list to set different rules for different classes)
exclude_threshold_diff : float | list, default 0.1
    the difference between soft and hard thresholds for excluded class prediction (in case hysteresis is used)
exclude_hysteresis : bool | list, default False
    if True, hysteresis is used for excluded class prediction (use a list to set different rules for different classes)
include_classes : list, optional
    a list of string names of classes that should be included into the active learning intervals
include_threshold : float | list, default 0.6
    the hard threshold for included class prediction (use a list to set different rules for different classes)
include_threshold_diff : float | list, default 0.1
    the difference between soft and hard thresholds for included class prediction (in case hysteresis is used)
include_hysteresis : bool | list, default False
    if True, hysteresis is used for included class prediction (use a list to set different rules for different classes)
data_path : str, optional
    the data path to run the prediction for
file_paths : set, optional
    a set of string file paths (data with all prefixes + feature files, in any order) to run the prediction
    for
parameters_update : dict, optional
    the parameters update dictionary
mode : {'all', 'test', 'val', 'train'}
    the subset of the data to make the prediction for (forced to 'all' if data_path is not None)
force : bool, default False
    if True and an episode with name episode_name already exists, it will be overwritten (use with caution!)
remove_saved_features : bool, default False
    if True, the dataset will be deleted.
cut_annotated : bool, default False
    if True, annotated frames will be cut from the suggestions
background_threshold : float, default 0.5
    the threshold for background prediction
def suggest_intervals_with_similarity(
    self,
    suggestions_name: str,
    prediction_name: str,
    target_video_id: str,
    target_clip: str,
    target_start: int,
    target_end: int,
    min_length: int = 60,
    n_intervals: int = 5,
    force: bool = False,
) -> None:
    """Suggest intervals based on similarity to a target interval.

    The similarity scores are computed by `self._generate_similarity_score`, turned
    into intervals by `self._suggest_intervals_from_dict`, and the intervals are
    pickled to `<project_path>/results/suggestions/<suggestions_name>/<suggestions_name>_al_points.pickle`.
    The metadata of the run is then registered with `self._save_suggestions`.

    Parameters
    ----------
    suggestions_name : str
        Name of the suggestion.
    prediction_name : str
        Name of the prediction to use.
    target_video_id : str
        Video id of the target interval.
    target_clip : str
        Clip id of the target interval.
    target_start : int
        Start frame of the target interval.
    target_end : int
        End frame of the target interval.
    min_length : int, default 60
        Minimum length of the suggested intervals.
    n_intervals : int, default 5
        Number of suggested intervals.
    force : bool, default False
        If True, the suggestion is overwritten if it already exists.

    """
    self._check_suggestions_validity(suggestions_name, force=force)
    print(f"SUGGESTION {suggestions_name}")
    score_dict = self._generate_similarity_score(
        prediction_name, target_video_id, target_clip, target_start, target_end
    )
    intervals = self._suggest_intervals_from_dict(
        score_dict, min_length, n_intervals
    )
    suggestions_path = os.path.join(
        self.project_path,
        "results",
        "suggestions",
        suggestions_name,
    )
    # Create missing parents as well and tolerate an existing folder:
    # a plain os.mkdir fails when "results/suggestions" does not exist yet.
    os.makedirs(suggestions_path, exist_ok=True)
    with open(
        os.path.join(suggestions_path, f"{suggestions_name}_al_points.pickle"), "wb"
    ) as f:
        pickle.dump(intervals, f)
    # NOTE(review): target_video_id is not stored in the metadata, so the target
    # interval cannot be fully identified from the saved record -- confirm whether
    # the suggestions table accepts an extra column before adding it.
    meta_parameters = {
        "prediction_name": prediction_name,
        "min_frames_suggestion": min_length,
        "n_intervals": n_intervals,
        "target_clip": target_clip,
        "target_start": target_start,
        "target_end": target_end,
    }
    self._save_suggestions(suggestions_name, {}, meta_parameters)
    print("\n")
Suggest intervals based on similarity to a target interval.
Parameters
suggestions_name : str Name of the suggestion. prediction_name : str Name of the prediction to use. target_video_id : str Video id of the target interval. target_clip : str Clip id of the target interval. target_start : int Start frame of the target interval. target_end : int End frame of the target interval. min_length : int, default 60 Minimum length of the suggested intervals. n_intervals : int, default 5 Number of suggested intervals. force : bool, default False If True, the suggestion is overwritten if it already exists.
def suggest_intervals_with_uncertainty(
    self,
    suggestions_name: str,
    episode_names: List,
    load_epochs: List = None,
    classes: List = None,
    n_frames: int = 10000,
    method: str = "least_confidence",
    min_length: int = 60,
    augment_n: int = 0,
    data_path: str = None,
    file_paths: Set = None,
    parameters_update: Dict = None,
    mode: str = "all",
    force: bool = False,
    remove_saved_features: bool = False,
) -> None:
    """Generate an active learning file based on model uncertainty.

    If you provide several episode names, the predicted probabilities will be averaged.
    The resulting intervals are pickled to
    `<project_path>/results/suggestions/<suggestions_name>/<suggestions_name>_al_points.pickle`
    and the run metadata is registered with `self._save_suggestions`.

    Parameters
    ----------
    suggestions_name : str
        the name of the suggestion
    episode_names : list
        a list of string episode names to load the models from (the behaviors dictionary
        is taken from the first run of the first episode)
    load_epochs : list, optional
        a list of epoch indices to load the models from (if `None`, the last ones will be used)
    classes : list, optional
        a list of classes to look at (by default all classes of the prediction)
    n_frames : int, default 10000
        the threshold total number of frames in the suggested intervals (in the end result it will most likely
        be slightly larger; it will only be smaller if the algorithm fails to find enough intervals
        with the set parameters)
    method : {"least_confidence", "entropy"}
        the method used to calculate the scores from the probability predictions (`"least_confidence"`: `1 - p_i` if
        `p_i > 0.5` or `p_i` if `p_i < 0.5`; `"entropy"`: `- p_i * log(p_i) - (1 - p_i) * log(1 - p_i)`)
    min_length : int, default 60
        the minimum number of frames in one interval
    augment_n : int, default 0
        the number of augmentations to average the predictions over
    data_path : str, optional
        the path to a data folder (by default, the project data is used)
    file_paths : set, optional
        a set of file paths (by default, the project data is used)
    parameters_update : dict, optional
        a dictionary of parameter updates
    mode : {"all", "test", "val", "train"}, default "all"
        the subset of the data to make the prediction for (per the original documentation it is
        forced to `'all'` if `data_path` is not `None`)
    force : bool, default False
        if `True`, existing suggestions with the same name will be overwritten
    remove_saved_features : bool, default False
        if `True`, the dataset will be deleted after the computation

    """
    self._check_suggestions_validity(suggestions_name, force=force)
    print(f"SUGGESTION {suggestions_name}")
    task, parameters, mode, predicted, _, behavior_dict = self._make_prediction(
        suggestions_name,
        episode_names,
        load_epochs,
        parameters_update,
        data_path=data_path,
        file_paths=file_paths,
        mode=mode,
        augment_n=augment_n,
        evaluate=False,
    )
    if classes is None:
        # Materialize the view: `dict.values()` cannot be pickled and would end up
        # in the saved metadata; the BALD variant stores a plain list as well.
        classes = list(behavior_dict.values())
    episode = self._episodes().get_runs(episode_names[0])[0]
    score_tensors = task.generate_uncertainty_score(
        classes,
        augment_n,
        method,
        predicted,
        self._episode(episode).get_behaviors_dict(),
    )
    intervals = self._suggest_intervals(
        task.dataset(mode), score_tensors, n_frames, min_length
    )
    # Report, per key, the number of intervals and the total number of frames.
    for key, found in intervals.items():
        total_frames = sum(x[1] - x[0] for x in found)
        print(f"{key}: {len(found)} ({total_frames})")
    if remove_saved_features:
        self._remove_stores(parameters)
    suggestions_path = os.path.join(
        self.project_path,
        "results",
        "suggestions",
        suggestions_name,
    )
    # Create missing parents as well and tolerate an existing folder.
    os.makedirs(suggestions_path, exist_ok=True)
    with open(
        os.path.join(suggestions_path, f"{suggestions_name}_al_points.pickle"), "wb"
    ) as f:
        pickle.dump(intervals, f)
    meta_parameters = {
        "suggestion_episode": episode_names,
        "suggestion_load_epoch": load_epochs,
        "suggestion_classes": classes,
        "min_frames_suggestion": min_length,
        "augment_n": augment_n,
        "method": method,
        "num_frames": n_frames,
    }
    self._save_suggestions(suggestions_name, {}, meta_parameters)
    print("\n")
Generate an active learning file based on model uncertainty.
If you provide several episode names, the predicted probabilities will be averaged.
Parameters
suggestions_name : str
    the name of the suggestion
episode_names : list
    a list of string episode names to load the models from
load_epochs : list, optional
    a list of epoch indices to load the models from (if None, the last ones will be used)
classes : list, optional
    a list of classes to look at (by default all)
n_frames : int, default 10000
    the threshold total number of frames in the suggested intervals (in the end result it will most likely
    be slightly larger; it will only be smaller if the algorithm fails to find enough intervals
    with the set parameters)
method : {"least_confidence", "entropy"}
    the method used to calculate the scores from the probability predictions ("least_confidence": 1 - p_i if
    p_i > 0.5 or p_i if p_i < 0.5; "entropy": - p_i * log(p_i) - (1 - p_i) * log(1 - p_i))
min_length : int, default 60
    the minimum number of frames in one interval
augment_n : int, default 0
    the number of augmentations to average the predictions over
data_path : str, optional
    the path to a data folder (by default, the project data is used)
file_paths : set, optional
    a list of file paths (by default, the project data is used)
parameters_update : dict, optional
    a dictionary of parameter updates
mode : {"test", "val", "train", "all"}
    the subset of the data to make the prediction for (forced to 'all' if data_path is not None;
    by default set to 'test' if the test subset is not empty, or to 'val' otherwise)
force : bool, default False
    if True, existing suggestions with the same name will be overwritten
remove_saved_features : bool, default False
    if True, the dataset will be deleted after the computation
def suggest_intervals_with_bald(
    self,
    suggestions_name: str,
    episode_name: str,
    load_epoch: int = None,
    classes: List = None,
    n_frames: int = 10000,
    num_models: int = 10,
    kernel_size: int = 11,
    min_length: int = 60,
    augment_n: int = 0,
    data_path: str = None,
    file_paths: Set = None,
    parameters_update: Dict = None,
    mode: str = "all",
    force: bool = False,
    remove_saved_features: bool = False,
) -> None:
    """Generate an active learning file based on Bayesian Active Learning by Disagreement.

    The resulting intervals are pickled to
    `<project_path>/results/suggestions/<suggestions_name>/<suggestions_name>_al_points.pickle`
    and the run metadata is registered with `self._save_suggestions`.

    Parameters
    ----------
    suggestions_name : str
        the name of the suggestion
    episode_name : str
        the name of the episode to load the model from
    load_epoch : int, optional
        the index of the epoch to load the model from (if `None`, the last one will be used)
    classes : list, optional
        a list of classes to look at (by default all classes in the task's behaviors dictionary)
    n_frames : int, default 10000
        the threshold total number of frames in the suggested intervals (in the end result it will most likely
        be slightly larger; it will only be smaller if the algorithm fails to find enough intervals
        with the set parameters)
    num_models : int, default 10
        the number of dropout masks to apply
    kernel_size : int, default 11
        the size of the smoothing kernel applied to the discrete results
    min_length : int, default 60
        the minimum number of frames in one interval
    augment_n : int, default 0
        the number of augmentations to average the predictions over
    data_path : str, optional
        the path to a data folder (by default, the project data is used)
    file_paths : set, optional
        a set of file paths (by default, the project data is used)
    parameters_update : dict, optional
        a dictionary of parameter updates
    mode : {"all", "test", "val", "train"}, default "all"
        the subset of the data to make the prediction for (per the original documentation it is
        forced to `'all'` if `data_path` is not `None`)
    force : bool, default False
        if `True`, existing suggestions with the same name will be overwritten
    remove_saved_features : bool, default False
        if `True`, the dataset will be deleted after the computation

    """
    self._check_suggestions_validity(suggestions_name, force=force)
    print(f"SUGGESTION {suggestions_name}")
    task, parameters, mode = self._make_task_prediction(
        suggestions_name,
        episode_name,
        parameters_update,
        load_epoch,
        data_path=data_path,
        file_paths=file_paths,
        mode=mode,
    )
    if classes is None:
        classes = list(task.behaviors_dict().values())
    score_tensors = task.generate_bald_score(
        classes, augment_n, num_models, kernel_size
    )
    intervals = self._suggest_intervals(
        task.dataset(mode), score_tensors, n_frames, min_length
    )
    if remove_saved_features:
        self._remove_stores(parameters)
    suggestions_path = os.path.join(
        self.project_path,
        "results",
        "suggestions",
        suggestions_name,
    )
    # Create missing parents as well and tolerate an existing folder:
    # a plain os.mkdir fails when "results/suggestions" does not exist yet.
    os.makedirs(suggestions_path, exist_ok=True)
    with open(
        os.path.join(suggestions_path, f"{suggestions_name}_al_points.pickle"), "wb"
    ) as f:
        pickle.dump(intervals, f)
    meta_parameters = {
        "suggestion_episode": episode_name,
        "suggestion_load_epoch": load_epoch,
        "suggestion_classes": classes,
        "min_frames_suggestion": min_length,
        "augment_n": augment_n,
        # the number of dropout masks is encoded in the method label
        "method": f"BALD:{num_models}",
        "num_frames": n_frames,
    }
    self._save_suggestions(suggestions_name, {}, meta_parameters)
    print("\n")
Generate an active learning file based on Bayesian Active Learning by Disagreement.
Parameters
suggestions_name : str
    the name of the suggestion
episode_name : str
    the name of the episode to load the model from
load_epoch : int, optional
    the index of the epoch to load the model from (if None, the last one will be used)
classes : list, optional
    a list of classes to look at (by default all)
n_frames : int, default 10000
    the threshold total number of frames in the suggested intervals (in the end result it will most likely
    be slightly larger; it will only be smaller if the algorithm fails to find enough intervals
    with the set parameters)
num_models : int, default 10
    the number of dropout masks to apply
kernel_size : int, default 11
    the size of the smoothing kernel applied to the discrete results
min_length : int, default 60
    the minimum number of frames in one interval
augment_n : int, default 0
    the number of augmentations to average the predictions over
data_path : str, optional
    the path to a data folder (by default, the project data is used)
file_paths : set, optional
    a list of file paths (by default, the project data is used)
parameters_update : dict, optional
    a dictionary of parameter updates
mode : {"test", "val", "train", "all"}
    the subset of the data to make the prediction for (forced to 'all' if data_path is not None;
    by default set to 'test' if the test subset is not empty, or to 'val' otherwise)
force : bool, default False
    if True, existing suggestions with the same name will be overwritten
remove_saved_features : bool, default False
    if True, the dataset will be deleted after the computation
def list_episodes(
    self,
    episode_names: List = None,
    value_filter: str = "",
    display_parameters: List = None,
    print_results: bool = True,
) -> pd.DataFrame:
    """Return (and optionally print) a filtered table of training episode metadata.

    The filtering itself is delegated to `self._episodes().list_episodes`.

    Parameters
    ----------
    episode_names : list
        a list of strings of episode names
    value_filter : str
        a string of filters to apply; of this general structure:
        'group_name1/par_name1::(</>/<=/>=/=)value1,group_name2/par_name2::(</>/<=/>=/=)value2', e.g.
        'data/overlap::=50,results/recall::>0.5,data/feature_extraction::=kinematic,meta/training_time::>=00:00:10'
    display_parameters : list
        list of parameters to display (e.g. ['data/overlap', 'results/recall'])
    print_results : bool, default True
        if True, the result will be printed to standard output

    Returns
    -------
    pd.DataFrame
        the filtered dataframe

    """
    table = self._episodes().list_episodes(
        episode_names, value_filter, display_parameters
    )
    if not print_results:
        return table
    # header, the table itself, then a trailing blank block
    for chunk in ("TRAINING EPISODES", table, "\n"):
        print(chunk)
    return table
Get a filtered pandas dataframe with episode metadata.
Parameters
episode_names : list
    a list of strings of episode names
value_filter : str
    a string of filters to apply; of this general structure:
    'group_name1/par_name1::(</>/<=/>=/=)value1,group_name2/par_name2::(</>/<=/>=/=)value2', e.g.
    'data/overlap::=50,results/recall::>0.5,data/feature_extraction::=kinematic,meta/training_time::>=00:00:10'
display_parameters : list
    list of parameters to display (e.g. ['data/overlap', 'results/recall'])
print_results : bool, default True
    if True, the result will be printed to standard output
Returns
pd.DataFrame the filtered dataframe
def list_predictions(
    self,
    episode_names: List = None,
    value_filter: str = "",
    display_parameters: List = None,
    print_results: bool = True,
) -> pd.DataFrame:
    """Return (and optionally print) a filtered table of prediction metadata.

    The filtering itself is delegated to `self._predictions().list_episodes`.

    Parameters
    ----------
    episode_names : list
        a list of strings of episode names
    value_filter : str
        a string of filters to apply; of this general structure:
        'group_name1/par_name1:(<>=)value1,group_name2/par_name2:(<>=)value2', e.g.
        'data/overlap:=50,results/recall:>0.5,data/feature_extraction:=kinematic'
    display_parameters : list
        list of parameters to display (e.g. ['data/overlap', 'results/recall'])
    print_results : bool, default True
        if True, the result will be printed to standard output

    Returns
    -------
    pd.DataFrame
        the filtered dataframe

    """
    store = self._predictions()
    table = store.list_episodes(episode_names, value_filter, display_parameters)
    if print_results:
        # one call, newline-separated: header, table, blank block
        print("PREDICTIONS", table, "\n", sep="\n")
    return table
Get a filtered pandas dataframe with prediction metadata.
Parameters
episode_names : list a list of strings of episode names value_filter : str a string of filters to apply; of this general structure: 'group_name1/par_name1:(<>=)value1,group_name2/par_name2:(<>=)value2', e.g. 'data/overlap:=50,results/recall:>0.5,data/feature_extraction:=kinematic' display_parameters : list list of parameters to display (e.g. ['data/overlap', 'results/recall']) print_results : bool, default True if True, the result will be printed to standard output
Returns
pd.DataFrame the filtered dataframe
def list_suggestions(
    self,
    suggestions_names: List = None,
    value_filter: str = "",
    display_parameters: List = None,
    print_results: bool = True,
) -> pd.DataFrame:
    """Return (and optionally print) a filtered table of suggestion metadata.

    The filtering itself is delegated to `self._suggestions().list_episodes`.

    Parameters
    ----------
    suggestions_names : list
        a list of strings of suggestion names
    value_filter : str
        a string of filters to apply; of this general structure:
        'group_name1/par_name1:(<>=)value1,group_name2/par_name2:(<>=)value2', e.g.
        'data/overlap:=50,results/recall:>0.5,data/feature_extraction:=kinematic'
    display_parameters : list
        list of parameters to display (e.g. ['data/overlap', 'results/recall'])
    print_results : bool, default True
        if True, the result will be printed to standard output

    Returns
    -------
    pd.DataFrame
        the filtered dataframe

    """
    table = self._suggestions().list_episodes(
        suggestions_names, value_filter, display_parameters
    )
    if not print_results:
        return table
    print("SUGGESTIONS")
    print(table)
    print("\n")
    return table
Get a filtered pandas dataframe with suggestion metadata.
Parameters
suggestions_names : list a list of strings of suggestion names value_filter : str a string of filters to apply; of this general structure: 'group_name1/par_name1:(<>=)value1,group_name2/par_name2:(<>=)value2', e.g. 'data/overlap:=50,results/recall:>0.5,data/feature_extraction:=kinematic' display_parameters : list list of parameters to display (e.g. ['data/overlap', 'results/recall']) print_results : bool, default True if True, the result will be printed to standard output
Returns
pd.DataFrame the filtered dataframe
def list_searches(
    self,
    search_names: List = None,
    value_filter: str = "",
    display_parameters: List = None,
    print_results: bool = True,
) -> pd.DataFrame:
    """Return (and optionally print) a filtered table of hyperparameter search metadata.

    The filtering itself is delegated to `self._searches().list_episodes`.

    Parameters
    ----------
    search_names : list
        a list of strings of search names
    value_filter : str
        a string of filters to apply; of this general structure:
        'group_name1/par_name1:(<>=)value1,group_name2/par_name2:(<>=)value2', e.g.
        'data/overlap:=50,results/recall:>0.5,data/feature_extraction:=kinematic'
    display_parameters : list
        list of parameters to display (e.g. ['data/overlap', 'results/recall'])
    print_results : bool, default True
        if True, the result will be printed to standard output

    Returns
    -------
    pd.DataFrame
        the filtered dataframe

    """
    store = self._searches()
    table = store.list_episodes(search_names, value_filter, display_parameters)
    if print_results:
        for chunk in ("SEARCHES", table, "\n"):
            print(chunk)
    return table
Get a filtered pandas dataframe with hyperparameter search metadata.
Parameters
search_names : list a list of strings of search names value_filter : str a string of filters to apply; of this general structure: 'group_name1/par_name1:(<>=)value1,group_name2/par_name2:(<>=)value2', e.g. 'data/overlap:=50,results/recall:>0.5,data/feature_extraction:=kinematic' display_parameters : list list of parameters to display (e.g. ['data/overlap', 'results/recall']) print_results : bool, default True if True, the result will be printed to standard output
Returns
pd.DataFrame the filtered dataframe
def get_best_parameters(
    self,
    search_name: str,
    round_to_binary: List = None,
):
    """Fetch the best parameters of a search and record the winning model name in them.

    Parameters
    ----------
    search_name : str
        the name of the search
    round_to_binary : list, default None
        a list of parameters to round to binary values (forwarded to the search records)

    Returns
    -------
    best_params : dict
        the result of merging (through `self._update`) the best parameters with
        `{"general": {"model_name": <best model name>}}`

    """
    best, model_name = self._searches().get_best_params(
        search_name, round_to_binary=round_to_binary
    )
    model_patch = {"general": {"model_name": model_name}}
    return self._update(best, model_patch)
Get the best parameters found by a search.
Parameters
search_name : str the name of the search round_to_binary : list, default None a list of parameters to round to binary values
Returns
best_params : dict a dictionary of the best parameters where the keys are in '{group}/{name}' format
def list_best_parameters(
    self, search_name: str, print_results: bool = True
) -> Dict:
    """Return (and optionally print) the raw best-parameter dictionary of a search.

    Parameters
    ----------
    search_name : str
        the name of the search
    print_results : bool, default True
        if True, the result will be printed to standard output

    Returns
    -------
    best_params : dict
        the dictionary returned by the search records for `search_name`

    """
    raw = self._searches().get_best_params_raw(search_name)
    if not print_results:
        return raw
    print(f"SEARCH RESULTS {search_name}")
    for par_name in raw:
        print(f"{par_name}: {raw[par_name]}")
    print("\n")
    return raw
Get the raw dictionary of best parameters found by a search.
Parameters
search_name : str the name of the search print_results : bool, default True if True, the result will be printed to standard output
Returns
best_params : dict a dictionary of the best parameters where the keys are in '{group}/{name}' format
def plot_episodes(
    self,
    episode_names: List,
    metrics: List | str,
    modes: List | str = None,
    title: str = None,
    episode_labels: List = None,
    save_path: str = None,
    add_hlines: List = None,
    epoch_limits: List = None,
    colors: List = None,
    add_highpoint_hlines: bool = False,
    remove_box: bool = False,
    font_size: float = None,
    linewidth: float = None,
    return_ax: bool = False,
) -> Optional[Any]:
    """Plot episode training curves.

    Parameters
    ----------
    episode_names : list
        a list of episode names to plot; to plot two episodes as one line, combine them in a list
        (e.g. ['episode1', ['episode2', 'episode3']] to plot episode2 and episode3 as one experiment)
    metrics : list | str
        a list of metrics to plot (a single metric name is also accepted)
    modes : list | str, optional
        a list of modes to plot ('train' and/or 'val'; `['val']` by default)
    title : str, optional
        title for the plot
    episode_labels : list, optional
        a list of strings used to label the curves (has to be the same length as episode_names)
    save_path : str, optional
        the path to save the resulting plot (ignored when `return_ax` is `True`)
    add_hlines : list, optional
        a list of float values (or (value, label) tuples) to mark with horizontal lines
    epoch_limits : list, optional
        a (min, max) pair passed to `ax.set_xlim`
    colors: list, optional
        a list of matplotlib colors (one per curve, i.e. metrics * modes * episode_names)
    add_highpoint_hlines : bool, default False
        if `True`, horizontal lines will be added at the highest value of each episode
    remove_box : bool, default False
        if `True`, the frame around the axes is hidden
    font_size : float, optional
        the font size to set through the matplotlib rc parameters
    linewidth : float, optional
        the line width of the mean curve of multi-run episodes
    return_ax : bool, default False
        if `True`, the axes are returned and the figure is neither saved nor shown

    Returns
    -------
    ax : matplotlib axes, optional
        the axes of the plot (only when `return_ax` is `True`)

    """
    if isinstance(metrics, str):
        metrics = [metrics]
    if isinstance(modes, str):
        modes = [modes]

    if font_size is not None:
        font = {"size": font_size}
        rc("font", **font)
    if modes is None:
        modes = ["val"]
    if add_hlines is None:
        add_hlines = []
    logs = []
    epochs = []
    labels = []
    if episode_labels is not None:
        assert len(episode_labels) == len(episode_names)
    for name_i, name in enumerate(episode_names):
        log_params = product(metrics, modes)
        for metric, mode in log_params:
            if episode_labels is not None:
                label = episode_labels[name_i]
            else:
                label = deepcopy(name)
            if len(modes) != 1:
                label += f"_{mode}"
            if len(metrics) != 1:
                label += f"_{metric}"
            labels.append(label)
            if isinstance(name, Iterable) and not isinstance(name, str):
                # A nested list of episodes is concatenated into one curve; the
                # "null" key collects single-run episodes, other keys are run indices.
                epoch_list = defaultdict(lambda: [])
                multi_logs = defaultdict(lambda: [])
                for i, n in enumerate(name):
                    runs = self._episodes().get_runs(n)
                    if len(runs) > 1:
                        for run in runs:
                            if "::" in run:
                                index = run.split("::")[-1]
                            else:
                                index = run.split("#")[-1]
                            if multi_logs[index] == []:
                                # "null" is set to None once a multi-run episode has been
                                # consumed, so a new index appearing later is an error.
                                if multi_logs["null"] is None:
                                    raise RuntimeError(
                                        "The run indices are not consistent across episodes!"
                                    )
                                else:
                                    multi_logs[index] += multi_logs["null"]
                            multi_logs[index] += list(
                                self._episode(run).get_metric_log(mode, metric)
                            )
                            start = (
                                0
                                if len(epoch_list[index]) == 0
                                else epoch_list[index][-1]
                            )
                            epoch_list[index] += [
                                x + start
                                for x in self._episode(run).get_epoch_list(mode)
                            ]
                        multi_logs["null"] = None
                    else:
                        if len(multi_logs.keys()) > 1:
                            raise RuntimeError(
                                "Cannot plot a single-run episode after a multi-run episode!"
                            )
                        multi_logs["null"] += list(
                            self._episode(n).get_metric_log(mode, metric)
                        )
                        start = (
                            0
                            if len(epoch_list["null"]) == 0
                            else epoch_list["null"][-1]
                        )
                        epoch_list["null"] += [
                            x + start for x in self._episode(n).get_epoch_list(mode)
                        ]
                if len(multi_logs.keys()) == 1:
                    log = multi_logs["null"]
                    epochs.append(epoch_list["null"])
                else:
                    log = tuple([v for k, v in multi_logs.items() if k != "null"])
                    epochs.append(
                        tuple([v for k, v in epoch_list.items() if k != "null"])
                    )
            else:
                runs = self._episodes().get_runs(name)
                if len(runs) > 1:
                    log = []
                    for run in runs:
                        tracked_metrics = self._episode(run).get_metrics()
                        if metric in tracked_metrics:
                            log.append(
                                list(self._episode(run).get_metric_log(mode, metric))
                            )
                        else:
                            # Fall back to averaging the "<metric>_<number>" logs.
                            relevant = []
                            for m in tracked_metrics:
                                m_split = m.split("_")
                                if (
                                    "_".join(m_split[:-1]) == metric
                                    and m_split[-1].isnumeric()
                                ):
                                    relevant.append(m)
                            if len(relevant) == 0:
                                raise ValueError(
                                    f"The {metric} metric was not tracked at {run}"
                                )
                            arr = 0
                            for m in relevant:
                                arr += self._episode(run).get_metric_log(mode, m)
                            arr /= len(relevant)
                            log.append(list(arr))
                    # A tuple marks "several parallel runs" for the plotting loop below.
                    log = tuple(log)
                    epochs.append(
                        tuple(
                            [self._episode(run).get_epoch_list(mode) for run in runs]
                        )
                    )
                else:
                    tracked_metrics = self._episode(name).get_metrics()
                    if metric in tracked_metrics:
                        log = list(self._episode(name).get_metric_log(mode, metric))
                    else:
                        # Fall back to averaging the "<metric>_<number>" logs.
                        relevant = []
                        for m in tracked_metrics:
                            m_split = m.split("_")
                            if (
                                "_".join(m_split[:-1]) == metric
                                and m_split[-1].isnumeric()
                            ):
                                relevant.append(m)
                        if len(relevant) == 0:
                            raise ValueError(
                                f"The {metric} metric was not tracked at {name}"
                            )
                        arr = 0
                        for m in relevant:
                            arr += self._episode(name).get_metric_log(mode, m)
                        arr /= len(relevant)
                        log = list(arr)
                    epochs.append(self._episode(name).get_epoch_list(mode))
            logs.append(log)
    if colors is None:
        colors = cm.rainbow(np.linspace(0, 1, len(logs)))
    if len(colors) != len(logs):
        raise ValueError(
            "The length of colors has to be equal to the length of curves (metrics * modes * episode_names)!"
        )
    f, ax = plt.subplots()
    length = 0
    for log, label, color, epoch_list in zip(logs, labels, colors, epochs):
        if type(log) is list:
            # Single curve.
            if len(log) > length:
                length = len(log)
            ax.plot(
                epoch_list,
                log,
                label=label,
                color=color,
            )
            if add_highpoint_hlines:
                ax.axhline(np.max(log), linestyle="dashed", color=color)
        else:
            # Parallel runs: faint individual curves plus their mean.
            for l, xx in zip(log, epoch_list):
                if len(l) > length:
                    length = len(l)
                ax.plot(
                    xx,
                    l,
                    color=color,
                    alpha=0.2,
                )
            if not all([len(x) == len(log[0]) for x in log]):
                warnings.warn(
                    f"Got logs with unequal lengths in parallel runs for {label}"
                )
                log = list(log)
                epoch_list = list(epoch_list)
                for i, x in enumerate(epoch_list):
                    # NOTE(review): `j` indexes `x[1:]`, so `x[j - 1]` is two positions
                    # behind `y` (and the last element when j == 0) -- kept as is,
                    # confirm the intended comparison.
                    to_remove = []
                    for j, y in enumerate(x[1:]):
                        if y <= x[j - 1]:
                            y_ind = x.index(y)
                            to_remove += list(range(y_ind, j))
                    epoch_list[i] = [
                        y for j, y in enumerate(x) if j not in to_remove
                    ]
                    log[i] = [y for j, y in enumerate(log[i]) if j not in to_remove]
                length = min([len(x) for x in log])
                for i in range(len(log)):
                    log[i] = log[i][:length]
                    epoch_list[i] = epoch_list[i][:length]
                if not all([x == epoch_list[0] for x in epoch_list]):
                    raise RuntimeError(
                        f"Got different epoch indices in parallel runs for {label}"
                    )
            mean = np.array(log).mean(0)
            ax.plot(
                epoch_list[0],
                mean,
                label=label,
                color=color,
                linewidth=linewidth,
            )
            if add_highpoint_hlines:
                ax.axhline(np.max(mean), linestyle="dashed", color=color)
    for x in add_hlines:
        label = None
        if isinstance(x, Iterable):
            x, label = x
        ax.axhline(x, label=label)
        # NOTE(review): assumed to sit inside the loop -- confirm against the original layout.
        ax.set_xlim((0, length))

    ax.legend()
    ax.set_xlabel("epochs")
    if len(metrics) == 1:
        ax.set_ylabel(metrics[0])
    else:
        ax.set_ylabel("value")
    if title is None:
        if len(episode_names) == 1:
            title = episode_names[0]
        elif len(metrics) == 1:
            title = metrics[0]
    if epoch_limits is not None:
        ax.set_xlim(epoch_limits)
    if title is not None:
        ax.set_title(title)
    if remove_box:
        # Axes objects have no `box` method (that is a pyplot function);
        # `set_frame_on` is the Axes-level call that hides the frame.
        ax.set_frame_on(False)
    if return_ax:
        return ax
    if save_path is not None:
        plt.savefig(save_path)
    plt.show()
Plot episode training curves.
Parameters
episode_names : list
    a list of episode names to plot; to plot two episodes as one line, combine them in a list
    (e.g. ['episode1', ['episode2', 'episode3']] to plot episode2 and episode3 as one experiment)
metrics : list
    a list of metrics to plot
modes : list, optional
    a list of modes to plot ('train' and/or 'val'; ['val'] by default)
title : str, optional
    title for the plot
episode_labels : list, optional
    a list of strings used to label the curves (has to be the same length as episode_names)
save_path : str, optional
    the path to save the resulting plot
add_hlines : list, optional
    a list of float values (or (value, label) tuples) to mark with horizontal lines
epoch_limits : list, optional
    a (min, max) pair used as the x-axis limits of the plot
colors: list, optional
    a list of matplotlib colors
add_highpoint_hlines : bool, default False
    if True, horizontal lines will be added at the highest value of each episode
def update_parameters(
    self,
    parameters_update: Dict = None,
    load_search: str = None,
    load_parameters: List = None,
    round_to_binary: List = None,
) -> None:
    """Update the parameters in the project config files.

    The passed `parameters_update` dictionary is not modified.

    Parameters
    ----------
    parameters_update : dict, optional
        a dictionary of parameter updates
    load_search : str, optional
        the name of hyperparameter search results to load to config
    load_parameters : list, optional
        a list of lists of string names of the parameters to load from the searches
    round_to_binary : list, optional
        a list of string names of the loaded parameters that should be rounded to the nearest power of two

    """
    keys = [
        "general",
        "losses",
        "metrics",
        "ssl",
        "training",
        "data",
    ]
    parameters = self._read_parameters(catch_blanks=False)
    if parameters_update is not None:
        # The model / features / augmentations groups live in their own files, whose
        # names depend on the (possibly updated) general parameters, so they are
        # applied separately. They are read with `.get` and filtered out of a new
        # dictionary instead of being popped from the caller's one.
        nested_groups = ("model", "features", "augmentations")
        model_params = parameters_update.get("model")
        feat_params = parameters_update.get("features")
        aug_params = parameters_update.get("augmentations")
        general_update = {
            group: value
            for group, value in parameters_update.items()
            if group not in nested_groups
        }

        parameters = self._update(parameters, general_update)
        model_name = parameters["general"]["model_name"]
        parameters["model"] = self._open_yaml(
            os.path.join(self.project_path, "config", "model", f"{model_name}.yaml")
        )
        if model_params is not None:
            parameters["model"] = self._update(parameters["model"], model_params)
        feat_name = parameters["general"]["feature_extraction"]
        parameters["features"] = self._open_yaml(
            os.path.join(
                self.project_path, "config", "features", f"{feat_name}.yaml"
            )
        )
        if feat_params is not None:
            parameters["features"] = self._update(
                parameters["features"], feat_params
            )
        aug_name = options.extractor_to_transformer[
            parameters["general"]["feature_extraction"]
        ]
        parameters["augmentations"] = self._open_yaml(
            os.path.join(
                self.project_path, "config", "augmentations", f"{aug_name}.yaml"
            )
        )
        if aug_params is not None:
            parameters["augmentations"] = self._update(
                parameters["augmentations"], aug_params
            )
    if load_search is not None:
        parameters_update, model_name = self._searches().get_best_params(
            load_search, load_parameters, round_to_binary
        )
        parameters["general"]["model_name"] = model_name
        # Reload the config of the model chosen by the search before applying its parameters.
        parameters["model"] = self._open_yaml(
            os.path.join(self.project_path, "config", "model", f"{model_name}.yaml")
        )
        parameters = self._update(parameters, parameters_update)
    # Write the top-level groups first, then the name-dependent files.
    for key in keys:
        with open(
            os.path.join(self.project_path, "config", f"{key}.yaml"),
            "w",
            encoding="utf-8",
        ) as f:
            YAML().dump(parameters[key], f)
    model_name = parameters["general"]["model_name"]
    model_path = os.path.join(
        self.project_path, "config", "model", f"{model_name}.yaml"
    )
    with open(model_path, "w", encoding="utf-8") as f:
        YAML().dump(parameters["model"], f)
    features_name = parameters["general"]["feature_extraction"]
    features_path = os.path.join(
        self.project_path, "config", "features", f"{features_name}.yaml"
    )
    with open(features_path, "w", encoding="utf-8") as f:
        YAML().dump(parameters["features"], f)
    aug_name = options.extractor_to_transformer[features_name]
    aug_path = os.path.join(
        self.project_path, "config", "augmentations", f"{aug_name}.yaml"
    )
    with open(aug_path, "w", encoding="utf-8") as f:
        YAML().dump(parameters["augmentations"], f)
Update the parameters in the project config files.
Parameters
parameters_update : dict, optional a dictionary of parameter updates load_search : str, optional the name of hyperparameter search results to load to config load_parameters : list, optional a list of lists of string names of the parameters to load from the searches round_to_binary : list, optional a list of string names of the loaded parameters that should be rounded to the nearest power of two
def get_summary(
    self,
    episode_names: list,
    method: str = "last",
    average: int = 1,
    metrics: List = None,
    return_values: bool = False,
) -> Union[Dict, Tuple[Dict, Dict]]:
    """Get a summary of episode statistics.

    If an episode has multiple runs, the statistics will be aggregated over all of them.
    The summary is always printed to standard output.

    Parameters
    ----------
    episode_names : list | str
        the names of the episodes (a single name is also accepted)
    method : str, default "last"
        the method for choosing the epochs: "best" (highest logged values), "last"
        (most recent logged values) or "epoch<N>" (the value logged for epoch N)
    average : int, default 1
        the number of epochs to average over (for each run; not used with "epoch<N>")
    metrics : list, optional
        a list of metrics (by default the metrics of the first run)
    return_values : bool, default False
        if `True`, the raw values used for the statistics are returned as well

    Returns
    -------
    statistics : dict
        a nested dictionary where first-level keys are metric names and second-level keys are 'mean' for the mean
        and 'std' for the standard deviation
    values : dict
        a dictionary of the collected values for each metric (only if `return_values` is `True`)

    """
    # A bare string would otherwise be iterated character by character.
    if isinstance(episode_names, str):
        episode_names = [episode_names]
    runs = []
    for episode_name in episode_names:
        runs_ep = self._episodes().get_runs(episode_name)
        if len(runs_ep) == 0:
            raise RuntimeError(
                f"There is no {episode_name} episode in the project memory"
            )
        runs += runs_ep
    if metrics is None:
        metrics = self._episode(runs[0]).get_metrics()

    values = {m: [] for m in metrics}
    for run in runs:
        for m in metrics:
            log = self._episode(run).get_metric_log(mode="val", metric_name=m)
            if method == "best":
                log = sorted(log)
                values[m] += list(log[-average:])
            elif method == "last":
                if len(log) == 0:
                    # No log: fall back to the value stored in the episode records.
                    episodes = self._episodes().data
                    if average == 1 and ("results", m) in episodes.columns:
                        values[m] += [episodes.loc[run, ("results", m)]]
                    else:
                        raise RuntimeError(f"Did not find {m} metric for {run} run")
                values[m] += list(log[-average:])
            elif method.startswith("epoch"):
                # Logs have one entry per validation interval, not per epoch.
                epoch = int(method[5:]) - 1
                pars = self._episodes().load_parameters(run)
                step = int(pars["training"]["validation_interval"])
                values[m] += [log[epoch // step]]
            else:
                raise ValueError(
                    f"The {method} method is not recognized! Please choose from ['last', 'best', 'epoch...']"
                )
    statistics = defaultdict(lambda: {})
    for m, v in values.items():
        statistics[m]["mean"] = np.mean(v)
        statistics[m]["std"] = np.std(v)
    print(f"SUMMARY {episode_names}")
    for m, v in statistics.items():
        print(f'{m}: mean {v["mean"]:.3f}, std {v["std"]:.3f}')
    print("\n")

    return (dict(statistics), values) if return_values else dict(statistics)
Get a summary of episode statistics.
If an episode has multiple runs, the statistics will be aggregated over all of them.
Parameters
episode_names : list the names of the episodes method : ["best", "last", "epoch<N>"] the method for choosing the epochs average : int, default 1 the number of epochs to average over (for each run) metrics : list, optional a list of metrics
Returns
statistics : dict a nested dictionary where first-level keys are metric names and second-level keys are 'mean' for the mean and 'std' for the standard deviation
@staticmethod
def remove_project(name: str, projects_path: str = None) -> None:
    """Delete a project folder together with everything stored in it.

    Nothing happens if the project folder does not exist.

    Parameters
    ----------
    name : str
        the name of the project to remove
    projects_path : str, optional
        the path to the projects directory (by default the home DLC2Action directory)

    """
    root = (
        os.path.join(str(Path.home()), "DLC2Action")
        if projects_path is None
        else projects_path
    )
    target = os.path.join(root, name)
    if not os.path.exists(target):
        return
    shutil.rmtree(target)
Remove all project files and experiment records and results.
Parameters
name : str the name of the project to remove projects_path : str, optional the path to the projects directory (by default the home DLC2Action directory)
def remove_saved_features(
    self,
    dataset_names: List = None,
    exceptions: List = None,
    remove_active: bool = False,
) -> None:
    """Remove saved pre-computed dataset feature files.

    By default, all features will be deleted.
    No essential information can get lost, storing them only saves time. Be careful with deleting datasets
    while training or inference is happening though.

    Parameters
    ----------
    dataset_names : list, optional
        a list of dataset names to delete (by default all names are added)
    exceptions : list, optional
        a list of dataset names to not be deleted (the passed list is not modified)
    remove_active : bool, default False
        if `False`, datasets used by unfinished episodes will not be deleted

    """
    print("Removing datasets...")
    if dataset_names is None:
        dataset_names = []
    # Build a new list so that the caller's `exceptions` is never extended in place.
    protected = [] if exceptions is None else list(exceptions)
    if not remove_active:
        protected += list(self._episodes().get_active_datasets())
    dataset_path = os.path.join(self.project_path, "saved_datasets")
    if os.path.exists(dataset_path):
        if dataset_names == []:
            # Datasets are stored as `<name>` folders and/or `<name>.pickle` files.
            dataset_names = {f.split(".")[0] for f in os.listdir(dataset_path)}

        to_remove = [
            x
            for x in dataset_names
            if os.path.exists(os.path.join(dataset_path, x)) and x not in protected
        ]
        if len(to_remove) > 2:
            to_remove = tqdm(to_remove)
        for dataset in to_remove:
            shutil.rmtree(os.path.join(dataset_path, dataset))
        to_remove = [
            f"{x}.pickle"
            for x in dataset_names
            if os.path.exists(os.path.join(dataset_path, f"{x}.pickle"))
            and x not in protected
        ]
        for dataset in to_remove:
            os.remove(os.path.join(dataset_path, dataset))
        # NOTE(review): every recorded dataset name is passed to `remove`, including
        # the ones kept on disk as exceptions -- confirm that this is intended.
        names = self._saved_datasets().dataset_names()
        self._saved_datasets().remove(names)
    print("\n")
Remove saved pre-computed dataset feature files.
By default, all features will be deleted. No essential information can get lost, storing them only saves time. Be careful with deleting datasets while training or inference is happening though.
Parameters
dataset_names : list, optional
    a list of dataset names to delete (by default all names are added)
exceptions : list, optional
    a list of dataset names to not be deleted
remove_active : bool, default False
    if False, datasets used by unfinished episodes will not be deleted
def remove_extra_checkpoints(
    self, episode_names: List = None, exceptions: List = None
) -> None:
    """Remove intermediate model checkpoint files (only leave the files for the last epoch).

    By default, all intermediate checkpoints will be deleted.
    Files in the model folder that are not associated with any record in the meta files are also deleted.
    Nothing happens if the model folder does not exist.

    Parameters
    ----------
    episode_names : list, optional
        a list of episode names to clean (by default all names are added)
    exceptions : list, optional
        a list of episode names to not clean

    """
    import re

    def natural_key(file_name):
        # Compare digit runs as numbers so that "epoch10" sorts after "epoch2";
        # `re.split` with a capturing group puts the digit runs at odd positions.
        return [
            int(token) if position % 2 else token
            for position, token in enumerate(re.split(r"(\d+)", file_name))
        ]

    model_path = os.path.join(self.project_path, "results", "model")
    if not os.path.isdir(model_path):
        # Nothing has been saved yet, so there is nothing to clean.
        return
    try:
        all_names = self._episodes().data.index
    except Exception:
        # Best effort: without readable records treat every folder as a known episode.
        all_names = os.listdir(model_path)
    if episode_names is None:
        episode_names = all_names
    if exceptions is None:
        exceptions = []
    to_remove = [x for x in episode_names if x not in exceptions]
    for folder in os.listdir(model_path):
        if folder not in all_names:
            # Folder without a record: delete it entirely.
            shutil.rmtree(os.path.join(model_path, folder))
        elif folder in to_remove:
            files = os.listdir(os.path.join(model_path, folder))
            for file in sorted(files, key=natural_key)[:-1]:
                os.remove(os.path.join(model_path, folder, file))
Remove intermediate model checkpoint files (only leave the files for the last epoch).
By default, all intermediate checkpoints will be deleted. Files in the model folder that are not associated with any record in the meta files are also deleted.
Parameters
episode_names : list, optional a list of episode names to clean (by default all names are added) exceptions : list, optional a list of episode names to not clean
def remove_search(self, search_name: str) -> None:
    """Delete the record of a hyperparameter search and its results folder.

    Parameters
    ----------
    search_name : str
        the name of the search to remove

    """
    self._searches().remove_episode(search_name)
    results_dir = os.path.join(
        self.project_path, "results", "searches", search_name
    )
    if not os.path.exists(results_dir):
        return
    shutil.rmtree(results_dir)
Remove a hyperparameter search record.
Parameters
search_name : str the name of the search to remove
def remove_suggestion(self, suggestion_name: str) -> None:
    """Delete the record of a suggestion and its results folder.

    Parameters
    ----------
    suggestion_name : str
        the name of the suggestion to remove

    """
    self._suggestions().remove_episode(suggestion_name)
    parts = (self.project_path, "results", "suggestions", suggestion_name)
    results_dir = os.path.join(*parts)
    if not os.path.exists(results_dir):
        return
    shutil.rmtree(results_dir)
Remove a suggestion record.
Parameters
suggestion_name : str the name of the suggestion to remove
def remove_prediction(self, prediction_name: str) -> None:
    """Delete the record of a prediction and its results folder.

    Parameters
    ----------
    prediction_name : str
        the name of the prediction to remove

    """
    self._predictions().remove_episode(prediction_name)
    results_dir = self.prediction_path(prediction_name)
    if not os.path.exists(results_dir):
        return
    shutil.rmtree(results_dir)
Remove a prediction record.
Parameters
prediction_name : str the name of the prediction to remove
3211 def check_prediction_exists(self, prediction_name: str) -> str | None: 3212 """Check if a prediction exists. 3213 3214 Parameters 3215 ---------- 3216 prediction_name : str 3217 the name of the prediction to check 3218 3219 Returns 3220 ------- 3221 str | None 3222 the path to the prediction if it exists, `None` otherwise 3223 3224 """ 3225 prediction_path = self.prediction_path(prediction_name) 3226 if os.path.exists(prediction_path): 3227 return prediction_path 3228 return None
Check if a prediction exists.
Parameters
prediction_name : str the name of the prediction to check
Returns
str | None
    the path to the prediction if it exists, None otherwise
def remove_episode(self, episode_name: str) -> None:
    """Delete the records, model folders and log files of an episode and all of its runs.

    Parameters
    ----------
    episode_name : str
        the name of the episode to remove

    """
    targets = self._episodes().get_runs(episode_name)
    # The episode name itself is cleaned as well, after its runs.
    targets.append(episode_name)
    results_root = os.path.join(self.project_path, "results")
    for target in targets:
        self._episodes().remove_episode(target)
        checkpoint_dir = os.path.join(results_root, "model", target)
        if os.path.exists(checkpoint_dir):
            shutil.rmtree(checkpoint_dir)
        log_file = os.path.join(results_root, "logs", f"{target}.txt")
        if os.path.exists(log_file):
            os.remove(log_file)
Remove all model, logs and metafile records related to an episode.
Parameters
episode_name : str the name of the episode to remove
def prune_unfinished(self, exceptions: List = None) -> List:
    """Remove all interrupted episodes.

    Remove all episodes that the episode records report as unfinished, plus every folder in the
    model directory that has no record. Note that it can remove episodes that are currently running!
    A missing model directory is treated as empty.

    Parameters
    ----------
    exceptions : list
        the episodes to keep even if they are interrupted (not applied to model folders without a record)

    Returns
    -------
    pruned : list
        a list of the episode names that were pruned

    """
    if exceptions is None:
        exceptions = []
    unfinished = self._episodes().unfinished_episodes()
    unfinished = [x for x in unfinished if x not in exceptions]
    model_root = os.path.join(self.project_path, "results", "model")
    model_folders = os.listdir(model_root) if os.path.isdir(model_root) else []
    if model_folders:
        # Read the recorded names once instead of once per folder.
        recorded = self._episodes().list_episodes().index
        unfinished += [x for x in model_folders if x not in recorded]
    print(f"PRUNING {unfinished}")
    for episode_name in unfinished:
        self.remove_episode(episode_name)
    print("\n")
    return unfinished
Remove all interrupted episodes.
Remove all episodes that either don't have a log file or have fewer epochs in the log file than in the training parameters or have a model folder but not a record. Note that it can remove episodes that are currently running!
Parameters
exceptions : list the episodes to keep even if they are interrupted
Returns
pruned : list a list of the episode names that were pruned
def prediction_path(self, prediction_name: str) -> str:
    """Build the path of the folder that holds the files of a prediction.

    Parameters
    ----------
    prediction_name : str
        name of the prediction

    Returns
    -------
    prediction_path : str
        the path `<project_path>/results/predictions/<prediction_name>`

    """
    parts = (self.project_path, "results", "predictions", f"{prediction_name}")
    return os.path.join(*parts)
Get the path where prediction files are saved.
Parameters
prediction_name : str name of the prediction
Returns
prediction_path : str the file path
def suggestion_path(self, suggestion_name: str) -> str:
    """Build the path of the folder that holds the files of a suggestion.

    Parameters
    ----------
    suggestion_name : str
        name of the suggestion

    Returns
    -------
    suggestion_path : str
        the path `<project_path>/results/suggestions/<suggestion_name>`

    """
    parts = (self.project_path, "results", "suggestions", f"{suggestion_name}")
    return os.path.join(*parts)
Get the path where suggestion files are saved.
Parameters
suggestion_name : str name of the suggestion
Returns
suggestion_path : str the file path
@classmethod
def print_data_types(cls):
    """Print the name and the docstring of every available data type."""
    print("DATA TYPES:")
    for type_name, store in cls.data_types().items():
        print(f"{type_name}:", store.__doc__, sep="\n")
Print available data types.
@classmethod
def print_annotation_types(cls):
    """Print the name and the docstring of every available annotation type."""
    print("ANNOTATION TYPES:")
    # `annotation_types()` is a mapping; iterating it directly yields only the keys,
    # which cannot be unpacked into (key, value) pairs.
    for key, value in cls.annotation_types().items():
        print(f"{key}:")
        print(value.__doc__)
Print available annotation types.
@staticmethod
def data_types() -> Dict:
    """Get available data types.

    Returns
    -------
    data_types : dict
        `options.input_stores`; the rest of this class uses it as a mapping
        (iterated with `.items()` and indexed by data type name)

    """
    return options.input_stores
Get available data types.
Returns
data_types : dict available data types (a dictionary keyed by data type name)
@staticmethod
def annotation_types() -> Dict:
    """Get available annotation types.

    Returns
    -------
    annotation_types : dict
        `options.annotation_stores`; the rest of this class indexes it by annotation type name

    """
    return options.annotation_stores
Get available annotation types.
Returns
dict available annotation types (a dictionary keyed by annotation type name)
def set_main_parameters(self, model_name: str = None, metric_names: List = None):
    """Select the model and the metrics and write them to the project config.

    Parameters
    ----------
    model_name : str, optional
        model name; run `project.help("model")` to find out more
    metric_names : list, optional
        a list of metric function names; run `project.help("metrics")` to find out more

    """
    general = {}
    if model_name is not None:
        assert model_name in options.models
        general["model_name"] = model_name
    if metric_names is not None:
        assert all(metric in options.metrics for metric in metric_names)
        general["metric_functions"] = metric_names
    self.update_parameters({"general": general})
Select the model and the metrics.
Parameters
model_name : str, optional
    model name; run project.help("model") to find out more
metric_names : list, optional
    a list of metric function names; run project.help("metrics") to find out more
def help(self, keyword: str = None):
    """Print information on the available options.

    Parameters
    ----------
    keyword : str, optional
        the keyword for options (run without arguments to see which keywords are available)

    Raises
    ------
    ValueError
        if the keyword is not recognized

    """
    # Keywords that print the docstrings of every entry of a dictionary in `options`:
    # (keyword, header, attribute name in `options`).
    catalogs = (
        ("model", "MODELS:", "models"),
        ("features", "FEATURE EXTRACTORS:", "feature_extractors"),
        ("metrics", "METRICS:", "metrics"),
    )
    matched = [entry[1:] for entry in catalogs if keyword == entry[0]]
    if keyword is None:
        overview = (
            "AVAILABLE HELP FUNCTIONS:",
            "- Try running `project.help(keyword)` with the following keywords:",
            " - model: to get more information on available models,",
            " - features: to get more information on available feature extraction modes,",
            " - partition_method: to get more information on available train/test/val partitioning methods,",
            " - metrics: to see a list of available metric functions.",
            " - data: to see help for expected data structure",
            "- To start working with this project, first run `project.list_blanks()` to check which parameters need to be filled in.",
            "- After a model and metrics are set, run `project.list_basic_parameters()` to see a list of the most important parameters that you might want to modify",
            "- If you want to dig deeper, get the full dictionary with project._read_parameters() (it is a `ruamel.yaml.comments.CommentedMap` instance).",
        )
        for line in overview:
            print(line)
    elif matched:
        header, attribute = matched[0]
        print(header)
        for entry_name, entry in getattr(options, attribute).items():
            print(f"{entry_name}:")
            print(entry.__doc__)
    elif keyword == "partition_method":
        print("PARTITION METHODS:")
        # Cut the partitioning method description out of the dataset docstring.
        full_doc = BehaviorDataset.partition_train_test_val.__doc__
        after_header = full_doc.split("The partitioning method:")[1]
        print(after_header.split("val_frac :")[0])
    elif keyword == "data":
        print("DATA:")
        print(f"Video data: {self.data_type}")
        print(options.input_stores[self.data_type].__doc__)
        print(f"Annotation data: {self.annotation_type}")
        print(options.annotation_stores[self.annotation_type].__doc__)
        print(
            "Annotation path and data path don't have to be separate, you can keep everything in one folder."
        )
    else:
        raise ValueError(f"The {keyword} keyword is not recognized")
    print("\n")
Get information on available options.
Parameters
keyword : str, optional the keyword for options (run without arguments to see which keywords are available)
def list_blanks(self, blanks=None):
    """Print the parameters that still need to be filled in.

    The output is a `project.update_parameters` call template with `...`
    in place of every missing value.

    Parameters
    ----------
    blanks : list, optional
        a list of the parameters to list, if already known (iterated as
        3-element entries: group key, parameter key, comment); read with
        `self._get_blanks()` when not given

    """
    if blanks is None:
        blanks = self._get_blanks()
    if not len(blanks) > 0:
        print("There is no blanks left!")
        return
    # group the blanks by their first-level key, keeping the order of first appearance
    grouped = {}
    for group_key, key, comment in blanks:
        grouped.setdefault(group_key, []).append((key, comment))
    print("Before running experiments, please update all the blanks.")
    print("To do that, you can run this.")
    print("--------------------------------------------------------")
    print(f"project.update_parameters(")
    print(f" {{")
    for group_key, entries in grouped.items():
        print(f' "{group_key}": {{')
        for key, comment in entries:
            print(f' "{key}": ..., {comment}')
        print(f" }}")
    print(f" }}")
    print(")")
    print("--------------------------------------------------------")
    print("Replace ... with relevant values.")
List parameters that need to be filled in.
Parameters
blanks : list, optional a list of the parameters to list, if already known
def list_basic_parameters(
    self,
):
    """Print the most relevant parameters together with code to modify them.

    The current values are read with `self._read_parameters()` and printed as a
    `project.update_parameters` call; every value is followed by the comment
    attached to it in the parameter file.
    """

    def _print_entries(keys, *path):
        # print `"key": value, comment` for every key present in the section at `path`;
        # the section is looked up lazily, only once there is a key to check
        for key in keys:
            section = parameters
            for step in path:
                section = section[step]
            if key in section:
                print(
                    f' "{key}": {self._process_value(section[key])}, {self._get_comment(section.ca.items.get(key))}'
                )

    parameters = self._read_parameters()
    print("BASIC PARAMETERS:")
    general = parameters["general"]
    model_name = general["model_name"]
    metric_names = general["metric_functions"]
    loss_name = general["loss_function"]
    feature_extraction = general["feature_extraction"]
    print("Here is a list of current parameters.")
    print(
        "You can copy this code, change the parameters you want to set and run it to update the project config."
    )
    print("--------------------------------------------------------")
    print("project.update_parameters(")
    print(" {")
    for group in ["general", "data", "training"]:
        print(f' "{group}": {{')
        _print_entries(options.basic_parameters[group], group)
        print(" },")
    print(' "losses": {')
    print(f' "{loss_name}": {{')
    _print_entries(options.basic_parameters["losses"][loss_name], "losses", loss_name)
    print(" },")
    print(" },")
    print(' "metrics": {')
    for metric in metric_names:
        print(f' "{metric}": {{')
        # all parameters of the selected metrics are listed, not only the basic ones
        _print_entries(parameters["metrics"][metric], "metrics", metric)
        print(" },")
    print(" },")
    print(' "model": {')
    _print_entries(options.basic_parameters["model"][model_name], "model")

    print(" },")
    print(' "features": {')
    _print_entries(options.basic_parameters["features"][feature_extraction], "features")

    print(" },")
    print(' "augmentations": {')
    _print_entries(
        options.basic_parameters["augmentations"][feature_extraction], "augmentations"
    )
    print(" },")
    print(" },")
    print(")")
    print("--------------------------------------------------------")
    print("\n")
Get a list of most relevant parameters and code to modify them.
def count_classes(
    self,
    load_episode: str = None,
    parameters_update: Dict = None,
    remove_saved_features: bool = False,
    bouts: bool = True,
) -> Dict:
    """Get a dictionary of class counts in different modes.

    Parameters
    ----------
    load_episode : str, optional
        the episode settings to load
    parameters_update : dict, optional
        a dictionary of parameter updates (only for "data" and "general" categories)
    remove_saved_features : bool, default False
        if `True`, the dataset that is used for computation is then deleted
    bouts : bool, default True
        if `True`, instead of frame counts segment counts are returned

    Returns
    -------
    class_counts : dict
        a dictionary where first-level keys are "train", "val" and "test", second-level keys are
        class names ("unknown" for indices missing from the behaviors dictionary) and values are
        class counts

    """
    if load_episode is not None:
        task, parameters, _ = self._make_task_prediction(
            "_",
            load_episode=load_episode,
            parameters_update=parameters_update,
        )
    else:
        task, parameters = self._make_task_training(
            episode_name="_", parameters_update=parameters_update, throwaway=True
        )
    raw_counts = task.count_classes(bouts=bouts)
    names = task.behaviors_dict()
    named_counts = {}
    for subset, counts in raw_counts.items():
        # translate class indices to class names
        named_counts[subset] = {
            names.get(index, "unknown"): count for index, count in counts.items()
        }
    if remove_saved_features:
        self._remove_stores(parameters)
    return named_counts
Get a dictionary of class counts in different modes.
Parameters
load_episode : str, optional
    the episode settings to load
parameters_update : dict, optional
    a dictionary of parameter updates (only for "data" and "general" categories)
remove_saved_features : bool, default False
    if True, the dataset that is used for computation is then deleted
bouts : bool, default True
    if True, instead of frame counts segment counts are returned
Returns
class_counts : dict a dictionary where first-level keys are "train", "val" and "test", second-level keys are class names and values are class counts (in segments if bouts is True, otherwise in frames)
def plot_class_distribution(
    self,
    parameters_update: Dict = None,
    frame_cutoff: int = 1,
    bout_cutoff: int = 1,
    print_full: bool = False,
    remove_saved_features: bool = False,
    save: str = None,
) -> None:
    """Make a class distribution plot.

    You can either specify the parameters, choose an existing dataset or do neither (in that case a dataset
    is created or loaded for the computation with the default parameters).

    One bar plot is made for every non-empty subset returned by `task.count_classes`,
    first for bout counts and then for frame counts.

    Parameters
    ----------
    parameters_update : dict, optional
        a dictionary of parameter updates (only for "data" and "general" categories)
    frame_cutoff : int, default 1
        the minimum number of frames for a class to be plotted
    bout_cutoff : int, default 1
        the minimum number of bouts for a class to be plotted
    print_full : bool, default False
        if `True`, the full class distribution is printed
    remove_saved_features : bool, default False
        if `True`, the dataset that is used for computation is then deleted
    save : str, optional
        the path to save the plots at; if not given, the plots are shown instead

    """
    task, parameters = self._make_task_training(
        episode_name="_", parameters_update=parameters_update, throwaway=True
    )
    behaviors = task.behaviors_dict()
    cutoff = {True: bout_cutoff, False: frame_cutoff}
    for bouts in [True, False]:
        class_counts = task.count_classes(bouts=bouts)
        if print_full:
            print("Bouts:" if bouts else "Frames:")
            for k, v in class_counts.items():
                if sum(v.values()) != 0:
                    print(f" {k}:")
                    # most frequent first; -100 is not a class label and is left out
                    ranked = [
                        x
                        for x in sorted(zip(v.values(), v.keys()), reverse=True)
                        if x[-1] != -100
                    ]
                    for vv, kk in ranked:
                        print(f" {behaviors[kk]}: {vv}")
        class_counts = {
            kk: {k: v for k, v in vv.items() if v >= cutoff[bouts]}
            for kk, vv in class_counts.items()
        }
        for key, d in class_counts.items():
            if sum(d.values()) == 0:
                continue
            ranked = [x for x in sorted(zip(d.values(), d.keys())) if x[-1] != -100]
            if not ranked:
                # only the ignored -100 label is left, nothing to plot
                continue
            values, keys = zip(*ranked)
            keys = [behaviors[x] for x in keys]
            plt.bar(keys, values)
            plt.title(key)
            plt.xticks(rotation=45, ha="right")
            plt.ylabel("bouts" if bouts else "frames")
            plt.tight_layout()

            # the original condition was inverted: it tried to save to `None`
            # and only showed the plot when a path was given
            if save is not None:
                # NOTE(review): every figure is written to the same path, so later
                # figures overwrite earlier ones - confirm whether one file per
                # subset is wanted
                plt.savefig(save)
                plt.close()
            else:
                plt.show()
    if remove_saved_features:
        self._remove_stores(parameters)
Make a class distribution plot.
You can either specify the parameters, choose an existing dataset or do neither (in that case a dataset is created or loaded for the computation with the default parameters).
Parameters
parameters_update : dict, optional
    a dictionary of parameter updates (only for "data" and "general" categories)
frame_cutoff : int, default 1
    the minimum number of frames for a segment to be considered
bout_cutoff : int, default 1
    the minimum number of bouts for a class to be considered
print_full : bool, default False
    if True, the full class distribution is printed
remove_saved_features : bool, default False
    if True, the dataset that is used for computation is then deleted
def plot_confusion_matrix(
    self,
    episode_name: str,
    load_epoch: int = None,
    parameters_update: Dict = None,
    metric: str = "recall",
    mode: str = "val",
    remove_saved_features: bool = False,
    save_path: str = None,
    cmap: str = "viridis",
) -> Tuple[ndarray, Iterable]:
    """Make a confusion matrix plot and return the data.

    If the annotation is non-exclusive, only false positive labels are considered.

    Parameters
    ----------
    episode_name : str
        the name of the episode to load
    load_epoch : int, optional
        the index of the epoch to load (by default the last one is loaded)
    parameters_update : dict, optional
        a dictionary of parameter updates (only for "data" and "general" categories)
    metric : {"recall", "precision"}
        for datasets with non-exclusive annotation, if `metric` is `"recall"`, only false positives are taken
        into account, and if `metric` is `"precision"`, only false negatives
    mode : {'val', 'all', 'test', 'train'}
        the subset of the data to make the prediction for
    remove_saved_features : bool, default False
        if `True`, the dataset that is used for computation is then deleted
    save_path : str, optional
        the path to save the plot at; if not given, the plot is shown instead
    cmap : str, default 'viridis'
        the colormap passed to `imshow`

    Returns
    -------
    confusion_matrix : np.ndarray
        a confusion matrix of shape `(#classes, #classes)` where `A[i, j] = F_ij/N_i`, `F_ij` is the number of
        frames that have the i-th label in the ground truth and a false positive j-th label in the prediction,
        `N_i` is the number of frames that have the i-th label in the ground truth
    classes : list
        a list of labels

    """
    task, parameters, mode = self._make_task_prediction(
        "_",
        load_episode=episode_name,
        load_epoch=load_epoch,
        parameters_update=parameters_update,
        mode=mode,
    )
    dataset = task.dataset(mode)
    prediction = task.predict(dataset, raw_output=True)
    # The original read and assigned the name `type` in this statement, which made it
    # an unbound local and crashed the call; the value to pass is the `metric` parameter.
    # The third returned value is presumably the metric that was actually applied
    # (the title below leaves it out when it is None) - TODO confirm.
    confusion_matrix, classes, metric = dataset.get_confusion_matrix(
        prediction, metric
    )
    if remove_saved_features:
        self._remove_stores(parameters)
    n_classes = len(classes)
    fig, ax = plt.subplots(figsize=(n_classes, n_classes))
    ax.imshow(confusion_matrix, cmap=cmap)
    # Show all ticks and label them with the respective list entries
    ax.set_xticks(np.arange(n_classes))
    ax.set_xticklabels(classes)
    ax.set_yticks(np.arange(n_classes))
    ax.set_yticklabels(classes)
    # Rotate the tick labels and set their alignment.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
    # Loop over data dimensions and create text annotations.
    for i in range(n_classes):
        for j in range(n_classes):
            ax.text(
                j,
                i,
                np.round(confusion_matrix[i, j], 2),
                ha="center",
                va="center",
                color="w",
            )
    if metric is not None:
        ax.set_title(f"{metric} {episode_name}")
    else:
        ax.set_title(episode_name)
    fig.tight_layout()
    if save_path is None:
        plt.show()
    else:
        plt.savefig(save_path)
        plt.close()
    return confusion_matrix, classes
Make a confusion matrix plot and return the data.
If the annotation is non-exclusive, only false positive labels are considered.
Parameters
episode_name : str
    the name of the episode to load
load_epoch : int, optional
    the index of the epoch to load (by default the last one is loaded)
parameters_update : dict, optional
    a dictionary of parameter updates (only for "data" and "general" categories)
metric : {"recall", "precision"}
    for datasets with non-exclusive annotation, if metric is "recall", only false positives are taken
    into account, and if metric is "precision", only false negatives
mode : {'val', 'all', 'test', 'train'}
    the subset of the data to make the prediction for (forced to 'all' if data_path is not None)
remove_saved_features : bool, default False
    if True, the dataset that is used for computation is then deleted
Returns
confusion_matrix : np.ndarray
    a confusion matrix of shape (#classes, #classes) where A[i, j] = F_ij/N_i, F_ij is the number of
    frames that have the i-th label in the ground truth and a false positive j-th label in the prediction,
    N_i is the number of frames that have the i-th label in the ground truth
classes : list
    a list of labels
def plot_ethograms(
    self,
    episode_name: str,
    prediction_name: str,
    start: int = 0,
    end: int = -1,
    save_path: str = None,
    cmap_pred: str = "binary",
    cmap_gt: str = "binary",
    fontsize: int = 22,
    time_mode: str = "frames",
    fps: int = None,
):
    """Plot ethograms from start to end time (in frames) for ground truth and prediction.

    Every file in the `results/predictions/<prediction_name>` folder of the project is
    matched with a ground truth file in `self.data_path` (the prediction suffix is
    replaced with the annotation suffix); files without a match are reported and skipped.

    Parameters
    ----------
    episode_name : str
        the name of the episode the behavior dictionary is taken from
    prediction_name : str
        the name of the prediction folder
    start : int, default 0
        the first frame to plot
    end : int, default -1
        the last frame to plot
    save_path : str, optional
        the folder to save the plots in (created if it does not exist)
    cmap_pred : str, default 'binary'
        the colormap for the prediction
    cmap_gt : str, default 'binary'
        the colormap for the ground truth
    fontsize : int, default 22
        the font size passed to the plotting function
    time_mode : str, default 'frames'
        the time mode passed to the plotting function
    fps : int, optional
        the frame rate passed to the plotting function

    """
    params = self._read_parameters(catch_blanks=False)
    parameters = self._get_data_pars(
        params,
    )
    if save_path is not None:
        os.makedirs(save_path, exist_ok=True)
    prediction_dir = os.path.join(
        self.project_path, "results", "predictions", prediction_name
    )
    prediction_suffix = "_".join(["_" + prediction_name, "prediction.pickle"])
    behaviors = None
    for file_name in os.listdir(prediction_dir):
        pred_path = os.path.join(prediction_dir, file_name)
        # NOTE(review): pickle files are executed on load, only use trusted prediction files
        predictions = load_pickle(pred_path)
        if behaviors is None:
            # the dictionary does not depend on the file, so it is loaded once
            behaviors = self.get_behavior_dictionary(episode_name)
        gt_filename = file_name.replace(
            prediction_suffix, parameters["annotation_suffix"]
        )
        gt_path = os.path.join(self.data_path, gt_filename)
        if not os.path.exists(gt_path):
            print("GT file not found")
            continue
        gt_data = load_pickle(gt_path)
        if save_path is None:
            # the original joined the file name with `None` and crashed here;
            # presumably the plotting helper shows the figure when `save` is None - TODO confirm
            save = None
        else:
            save = os.path.join(
                save_path, os.path.splitext(file_name)[0] + "_gt_pred"
            )
        self._plot_ethograms_gt_pred(
            gt_data,
            predictions,
            gt_data[1],
            behaviors,
            start=start,
            end=end,
            save=save,
            cmap_pred=cmap_pred,
            cmap_gt=cmap_gt,
            fontsize=fontsize,
            time_mode=time_mode,
            fps=fps,
        )
Plot ethograms from start to end time (in frames) for ground truth and prediction
def create_annotated_video(
    self,
    prediction_file_paths: list,
    video_file_paths: list,
    episode_name: str,  # To get the list of behaviors
    ground_truth_file_paths: list = None,
    pred_thresh: float = 0.5,
    start: int = 0,
    end: int = -1,
):
    """Create a video with the predictions overlaid on the video.

    For every (prediction file, video file) pair an `<video name>_annotated.mp4`
    file is written next to the video; every frame gets a side panel made by
    `self._create_side_panel` from the thresholded prediction (and the ground
    truth, when given).

    Parameters
    ----------
    prediction_file_paths : list
        paths to the prediction pickle files
    video_file_paths : list
        paths to the videos, in the same order as the prediction files
    episode_name : str
        the name of the episode the behavior dictionary is taken from
    ground_truth_file_paths : list, optional
        paths to the ground truth pickle files, in the same order as the videos
    pred_thresh : float, default 0.5
        predictions above this value are treated as positive
    start : int, default 0
        the first frame to annotate
    end : int, default -1
        the frame to stop at (if negative, the frame count of each video is used)

    """
    # the labels do not depend on the video, so they are prepared once
    behaviors = self.get_behavior_dictionary(episode_name)
    labels_pred = [behaviors[i] for i in range(len(behaviors))]
    labels_pred = np.roll(labels_pred, 1).tolist()

    for k, (pred_path, vid_path) in enumerate(
        zip(prediction_file_paths, video_file_paths)
    ):
        print("Generating video for :", os.path.basename(vid_path))
        # NOTE(review): pickle files are executed on load, only use trusted files
        predictions = load_pickle(pred_path)
        # only the first entry of the prediction dictionary is used; it is indexed as [class, frame]
        best_pred = predictions[next(iter(predictions))].numpy() > pred_thresh
        n_pred_frames = best_pred.shape[1]

        # the original used these names even when no ground truth was given,
        # which crashed on the default path
        labels_gt = None
        gt_data = None
        if ground_truth_file_paths is not None:
            gt_data = load_pickle(ground_truth_file_paths[k])
            labels_gt = gt_data[1]
            gt_data = binarize_data(gt_data, max_frame=n_pred_frames)

        cap = cv2.VideoCapture(vid_path)
        out = None
        bar = None
        try:
            cap.set(cv2.CAP_PROP_POS_FRAMES, start)
            # kept per video: the original overwrote `end`, so the frame count of the
            # first video was reused for all the following ones
            last_frame = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) if end < 0 else end
            # never read past the frames a prediction exists for
            last_frame = min(last_frame, n_pred_frames)
            fps = cap.get(cv2.CAP_PROP_FPS)
            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            fourcc = cv2.VideoWriter_fourcc(*"mp4v")
            # NOTE(review): the writer is opened with a fixed (600, 300) frame size while
            # the frames written below are the video frame plus side panel scaled by 0.25;
            # confirm the two sizes agree for the videos in use
            out = cv2.VideoWriter(
                os.path.join(
                    os.path.dirname(vid_path),
                    os.path.splitext(os.path.basename(vid_path))[0] + "_annotated.mp4",
                ),
                fourcc,
                fps,
                (600, 300),
            )
            bar = tqdm(total=max(last_frame - start, 0))
            # absolute frame index: reading starts at `start`, so the labels
            # have to be taken from the same position
            frame_index = start
            while cap.isOpened() and frame_index < last_frame:
                ret, frame = cap.read()
                if not ret:
                    break

                # presumably the side panel helper accepts None for a missing
                # ground truth - TODO confirm
                side_panel = self._create_side_panel(
                    height,
                    width,
                    labels_pred,
                    best_pred[:, frame_index],
                    labels_gt,
                    None if gt_data is None else gt_data[:, frame_index],
                )
                frame = np.concatenate((frame, side_panel), axis=1)
                frame = cv2.resize(frame, (0, 0), fx=0.25, fy=0.25)
                out.write(frame)
                frame_index += 1
                bar.update(1)
        finally:
            # release the handles even if a frame fails to process
            cap.release()
            if out is not None:
                out.release()
            if bar is not None:
                bar.close()
            cv2.destroyAllWindows()
Create a video with the predictions overlaid on the video
def plot_predictions(
    self,
    episode_name: str,
    load_epoch: int = None,
    parameters_update: Dict = None,
    add_legend: bool = True,
    ground_truth: bool = True,
    colormap: str = "dlc2action",
    hide_axes: bool = False,
    min_classes: int = 1,
    width: float = 10,
    whole_video: bool = False,
    transparent: bool = False,
    drop_classes: Set = None,
    search_classes: Set = None,
    num_plots: int = 1,
    remove_saved_features: bool = False,
    smooth_interval_prediction: int = 0,
    data_path: str = None,
    file_paths: Set = None,
    mode: str = "val",
    font_size: float = None,
    window_size: int = 400,
) -> None:
    """Visualize random predictions.

    The plot is saved as `<episode_name>_prediction.svg` in the `results/plots`
    folder of the project.

    Parameters
    ----------
    episode_name : str
        the name of the episode to load
    load_epoch : int, optional
        the epoch to load (by default last)
    parameters_update : dict, optional
        parameter update dictionary
    add_legend : bool, default True
        if True, legend will be added to the plot
    ground_truth : bool, default True
        if True, ground truth will be added to the plot
    colormap : str, default 'dlc2action'
        the colormap to use
    hide_axes : bool, default False
        if `True`, the axes will be hidden on the plot
    min_classes : int, default 1
        the minimum number of classes in a displayed interval
    width : float, default 10
        the width of the plot
    whole_video : bool, default False
        if `True`, whole videos are plotted instead of segments
    transparent : bool, default False
        if `True`, the background on the plot is transparent
    drop_classes : set, optional
        a set of class names to not be displayed
    search_classes : set, optional
        if given, only intervals where at least one of the classes is in ground truth will be shown
    num_plots : int, default 1
        the number of plots to make
    remove_saved_features : bool, default False
        if `True`, the dataset will be deleted after computation
    smooth_interval_prediction : int, default 0
        if >0, predictions shorter than this number of frames are removed (filled with prediction for the previous frame)
    data_path : str, optional
        the data path to run the prediction for
    file_paths : set, optional
        a set of string file paths (data with all prefixes + feature files, in any order) to run the prediction
        for
    mode : {'all', 'test', 'val', 'train'}
        the subset of the data to make the prediction for (forced to 'all' if data_path is not None)
    font_size : float, optional
        the font size passed on to the plotting function
    window_size : int, default 400
        the window size passed on to the plotting function

    """
    plots_dir = os.path.join(self.project_path, "results", "plots")
    task, parameters, mode = self._make_task_prediction(
        "_",
        load_episode=episode_name,
        parameters_update=parameters_update,
        load_epoch=load_epoch,
        data_path=data_path,
        file_paths=file_paths,
        mode=mode,
    )
    os.makedirs(plots_dir, exist_ok=True)
    # `mode` is the value returned by the task constructor, not necessarily the one passed in
    plot_options = {
        "save_path": os.path.join(plots_dir, f"{episode_name}_prediction.svg"),
        "add_legend": add_legend,
        "ground_truth": ground_truth,
        "colormap": colormap,
        "hide_axes": hide_axes,
        "min_classes": min_classes,
        "whole_video": whole_video,
        "transparent": transparent,
        "dataset": mode,
        "drop_classes": drop_classes,
        "search_classes": search_classes,
        "width": width,
        "smooth_interval_prediction": smooth_interval_prediction,
        "font_size": font_size,
        "num_plots": num_plots,
        "window_size": window_size,
    }
    task.visualize_results(**plot_options)
    if remove_saved_features:
        self._remove_stores(parameters)
Visualize random predictions.
Parameters
episode_name : str
    the name of the episode to load
load_epoch : int, optional
    the epoch to load (by default last)
parameters_update : dict, optional
    parameter update dictionary
add_legend : bool, default True
    if True, legend will be added to the plot
ground_truth : bool, default True
    if True, ground truth will be added to the plot
colormap : str, default 'dlc2action'
    the matplotlib colormap to use
hide_axes : bool, default False
    if True, the axes will be hidden on the plot
min_classes : int, default 1
    the minimum number of classes in a displayed interval
width : float, default 10
    the width of the plot
whole_video : bool, default False
    if True, whole videos are plotted instead of segments
transparent : bool, default False
    if True, the background on the plot is transparent
drop_classes : set, optional
    a set of class names to not be displayed
search_classes : set, optional
    if given, only intervals where at least one of the classes is in ground truth will be shown
num_plots : int, default 1
    the number of plots to make
remove_saved_features : bool, default False
    if True, the dataset will be deleted after computation
smooth_interval_prediction : int, default 0
    if >0, predictions shorter than this number of frames are removed (filled with prediction for the previous frame)
data_path : str, optional
    the data path to run the prediction for
file_paths : set, optional
    a set of string file paths (data with all prefixes + feature files, in any order) to run the prediction
    for
mode : {'all', 'test', 'val', 'train'}
    the subset of the data to make the prediction for (forced to 'all' if data_path is not None)
def create_video_from_labels(
    self,
    video_dir_path: str,
    mode="ground_truth",
    prediction_name: str = None,
    save_path: str = None,
):
    """Cut the videos into clips according to annotation or prediction intervals.

    Every video in `video_dir_path` that has a matching annotation (or prediction)
    file is passed to `self._extract_videos` once per individual and behavior,
    together with the intervals of that behavior.

    Parameters
    ----------
    video_dir_path : str
        the path to the folder with the videos
    mode : {'ground_truth', 'prediction'}
        the source of the labels ('predictions' is accepted as an alias of 'prediction')
    prediction_name : str, optional
        the name of the prediction folder (has to be given with mode 'prediction')
    save_path : str, optional
        the folder to save the clips in (by default `results/annotated_videos_from_<mode>`
        in the project folder)

    Raises
    ------
    ValueError
        if the mode is not recognized

    """
    if save_path is None:
        save_path = os.path.join(
            self.project_path, "results", f"annotated_videos_from_{mode}"
        )
    # the original checked for "prediction" in one place and "predictions" in another,
    # so neither spelling worked; both are accepted now
    if mode == "predictions":
        mode = "prediction"
    if mode not in ("ground_truth", "prediction"):
        raise ValueError(
            f"The {mode} mode is not recognized; please choose from ['ground_truth', 'prediction']"
        )
    os.makedirs(save_path, exist_ok=True)

    params = self._read_parameters(catch_blanks=False)

    if mode == "ground_truth":
        source_dir = self.annotation_path
        annotation_suffix = params["data"]["annotation_suffix"]
    else:
        assert (
            prediction_name is not None
        ), "Please provide a prediction name with mode 'prediction'"
        source_dir = os.path.join(
            self.project_path, "results", "predictions", prediction_name
        )
        annotation_suffix = f"_{prediction_name}_prediction.pickle"

    def _find_annotation(video_name):
        # the annotation file is the video name without extension followed by the suffix;
        # the name the original code produced (extension text replaced with the suffix)
        # is still tried as a fallback
        candidates = (
            os.path.splitext(video_name)[0] + annotation_suffix,
            video_name.replace(video_name.split(".")[-1], annotation_suffix),
        )
        for candidate in candidates:
            path = os.path.join(source_dir, candidate)
            if os.path.exists(path):
                return path
        return None

    video_annotation_pairs = []
    for video_name in os.listdir(video_dir_path):
        annotation_file = _find_annotation(video_name)
        if annotation_file is not None:
            video_annotation_pairs.append(
                (os.path.join(video_dir_path, video_name), annotation_file)
            )

    for video_file, annotation_file in tqdm(video_annotation_pairs):
        if not os.path.exists(video_file):
            print(f"Video file {video_file} does not exist, skipping.")
            continue
        if not os.path.exists(annotation_file):
            print(f"Annotation file {annotation_file} does not exist, skipping.")
            continue

        if annotation_file.endswith(".pickle"):
            # NOTE(review): pickle files are executed on load, only use trusted files
            annotations = load_pickle(annotation_file)
        elif annotation_file.endswith(".csv"):
            annotations = pd.read_csv(annotation_file)
        else:
            # the original went on with an undefined (or stale) annotation here
            print(f"Annotation file {annotation_file} has an unsupported format, skipping.")
            continue

        if mode == "ground_truth":
            # entry 3 is read as a {individual: {behavior: intervals}} mapping
            intervals = annotations[3]
        else:
            # The original built a 3-element list and then read index 3 from it, so this
            # branch could not work; the mapping is built in the same shape as the ground truth one.
            # NOTE(review): the prediction arrays are assumed to be indexed as [class, frame],
            # as in `create_annotated_video`; the original took the argmax over axis 1 - confirm.
            behaviors = list(annotations["classes"].values())
            intervals = {}
            for individual, scores in annotations.items():
                if individual in ("classes", "min_frame", "max_frame"):
                    continue
                scores = np.asarray(scores)
                if params["general"]["exclusive"]:
                    labels = np.argmax(scores, axis=0)
                    intervals[individual] = {
                        behavior: self._bin_array_to_sequences(labels, target_value=k)
                        for k, behavior in enumerate(behaviors)
                    }
                else:
                    active = (scores > 0.5).astype(int)
                    intervals[individual] = {
                        behavior: self._bin_array_to_sequences(
                            active[k], target_value=1
                        )
                        for k, behavior in enumerate(behaviors)
                    }

        for individual in intervals:
            for behavior in intervals[individual]:
                self._extract_videos(
                    video_file,
                    intervals[individual][behavior],
                    behavior,
                    individual,
                    save_path,
                    resolution=(640, 480),
                    fps=30,
                )
def create_metadata_backup(self) -> None:
    """Copy the files of the project `meta` folder to `meta/backup`.

    An existing backup is deleted first; subfolders of `meta` are not copied.
    """
    meta_dir = os.path.join(self.project_path, "meta")
    backup_dir = os.path.join(meta_dir, "backup")
    if os.path.exists(backup_dir):
        shutil.rmtree(backup_dir)
    os.mkdir(backup_dir)
    for entry in os.listdir(meta_dir):
        source = os.path.join(meta_dir, entry)
        # the backup folder itself and any other folder are left out
        if entry != "backup" and not os.path.isdir(source):
            shutil.copy(source, os.path.join(backup_dir, entry))
Create a copy of the meta files.
def load_metadata_backup(self) -> None:
    """Restore the meta files from `meta/backup` (in case of corruption).

    Every entry of the backup folder is copied over the file with the same
    name in the project `meta` folder.
    """
    meta_dir = os.path.join(self.project_path, "meta")
    backup_dir = os.path.join(meta_dir, "backup")
    for entry in os.listdir(backup_dir):
        source = os.path.join(backup_dir, entry)
        target = os.path.join(meta_dir, entry)
        shutil.copy(source, target)
Load from previously created meta data backup (in case of corruption).
def get_behavior_dictionary(self, episode_name: str) -> Dict:
    """Get the behavior dictionary for an episode.

    Parameters
    ----------
    episode_name : str
        the name of the episode

    Returns
    -------
    behaviors_dictionary : dict
        a dictionary where keys are label indices and values are label names

    """
    episode = self._episode(episode_name)
    return episode.get_behaviors_dict()
Get the behavior dictionary for an episode.
Parameters
episode_name : str the name of the episode
Returns
behaviors_dictionary : dict a dictionary where keys are label indices and values are label names
def import_episodes(
    self,
    episodes_directory: str,
    name_map: Dict = None,
    repeat_policy: str = "error",
) -> None:
    """Import episodes exported with `Project.export_episodes`.

    The episode records are added to the project and the model and log files
    are copied to the `results` folder of the project.

    Parameters
    ----------
    episodes_directory : str
        the path to the exported episodes directory
    name_map : dict, optional
        a name change dictionary for the episodes: keys are old names, values are new names
    repeat_policy : {'error', 'skip', 'force'}, default 'error'
        the policy for repeated episode names: 'error' raises an error, 'skip' skips duplicates,
        'force' overwrites existing episodes

    Raises
    ------
    ValueError
        if an episode name is already taken and the policy is 'error', or if the policy
        is not recognized when a name is already taken

    """
    if name_map is None:
        name_map = {}
    # NOTE(review): pickle files are executed on load, only import episodes from trusted sources
    episodes = pd.read_pickle(os.path.join(episodes_directory, "episodes.pickle"))
    to_remove = []
    import_string = "Imported episodes: "
    for old_name in episodes.index:
        new_name = name_map.get(old_name, old_name)
        try:
            self._check_episode_validity(new_name, allow_doublecolon=True)
        except ValueError as e:
            # NOTE(review): validation errors other than a taken name are ignored here,
            # as in the original - confirm that this is intended
            if str(e).endswith("is already taken!"):
                if repeat_policy == "skip":
                    # the table is indexed by the old names, so the old name has to be
                    # dropped (the original dropped the renamed one and failed)
                    to_remove.append(old_name)
                    continue
                elif repeat_policy == "force":
                    self.remove_episode(new_name)
                elif repeat_policy == "error":
                    raise ValueError(
                        f"The {new_name} episode name is already taken; please use the name_map parameter to rename it"
                    )
                else:
                    raise ValueError(
                        f"The {repeat_policy} repeat policy is not recognized; please choose from ['skip', 'force' and 'error']"
                    )
        # only the episodes that are actually imported are reported
        if old_name in name_map:
            import_string += f"{old_name} ({new_name}), "
        else:
            import_string += f"{old_name}, "
    episodes = episodes.drop(index=to_remove)
    self._episodes().update(
        episodes,
        name_map=name_map,
        force=(repeat_policy == "force"),
        data_path=self.data_path,
        annotation_path=self.annotation_path,
    )
    for old_name in episodes.index:
        new_name = name_map.get(old_name, old_name)
        model_dir = os.path.join(self.project_path, "results", "model", new_name)
        old_model_dir = os.path.join(episodes_directory, "model", old_name)
        if os.path.exists(model_dir):
            shutil.rmtree(model_dir)
        os.mkdir(model_dir)
        for file in os.listdir(old_model_dir):
            shutil.copyfile(
                os.path.join(old_model_dir, file), os.path.join(model_dir, file)
            )
        log_file = os.path.join(
            self.project_path, "results", "logs", f"{new_name}.txt"
        )
        old_log_file = os.path.join(episodes_directory, "logs", f"{old_name}.txt")
        shutil.copyfile(old_log_file, log_file)
    print(import_string)
    print("\n")
Import episodes exported with Project.export_episodes.
Parameters
episodes_directory : str
    the path to the exported episodes directory
name_map : dict, optional
    a name change dictionary for the episodes: keys are old names, values are new names
repeat_policy : {'error', 'skip', 'force'}, default 'error'
    the policy for repeated episode names: 'error' raises an error, 'skip' skips duplicates,
    'force' overwrites existing episodes
def export_episodes(
    self, episode_names: List, output_directory: str, name: str = None
) -> None:
    """Save selected episodes to a directory that `Project.import_episodes` can read.

    The export directory contains a `model` folder and a `logs` folder with the files of
    every run of the selected episodes, plus an `episodes.pickle` file with their records.

    Parameters
    ----------
    episode_names : list
        a list of string episode names
    output_directory : str
        the path to the directory where the episodes will be saved
    name : str, optional
        the name of the episodes directory (by default `exported_episodes`)

    """

    def _is_taken(candidate: str) -> bool:
        # A name counts as taken if either a zip archive or a folder already uses it.
        return os.path.exists(
            os.path.join(output_directory, candidate + ".zip")
        ) or os.path.exists(os.path.join(output_directory, candidate))

    if name is None:
        name = "exported_episodes"
    if _is_taken(name):
        # Append the first free numeric suffix.
        suffix = 1
        while _is_taken(name + f"_{suffix}"):
            suffix += 1
        name = name + f"_{suffix}"

    export_root = os.path.join(output_directory, name)
    for folder in (
        export_root,
        os.path.join(export_root, "model"),
        os.path.join(export_root, "logs"),
    ):
        os.mkdir(folder)

    runs = []
    for episode in episode_names:
        runs.extend(self._episodes().get_runs(episode))

    results_root = os.path.join(self.project_path, "results")
    for run in runs:
        shutil.copytree(
            os.path.join(results_root, "model", run),
            os.path.join(export_root, "model", run),
        )
        log_name = f"{run}.txt"
        shutil.copyfile(
            os.path.join(results_root, "logs", log_name),
            os.path.join(export_root, "logs", log_name),
        )

    subset = self._episodes().get_subset(runs)
    subset.to_pickle(os.path.join(export_root, "episodes.pickle"))
Save selected episodes to a directory that can be imported into another project with Project.import_episodes.
Parameters
episode_names : list
    a list of string episode names
output_directory : str
    the path to the directory where the episodes will be saved
name : str, optional
    the name of the episodes directory (by default exported_episodes)
def get_results_table(
    self,
    episode_names: List,
    metrics: List = None,
    mode: str = "mean",  # Choose between ["mean", "statistics", "detail"]
    print_results: bool = True,
    classes: List = None,
):
    """Generate a `pandas` dataframe with a summary of episode results.

    Rows of the table are behaviors (plus an "average" row), columns are named
    `"<episode> <metric>"` ("mean" mode), `"<episode> <metric> mean"` / `"<episode> <metric> std"`
    ("statistics" mode) or `"<episode>::<i> <metric>"` ("detail" mode, one column per run).
    Missing (NaN) values are ignored when aggregating.

    Parameters
    ----------
    episode_names : list
        a list of names of episodes to include
    metrics : list, optional
        a list of metric names to include
    mode : str, optional
        the mode of the results table, choose between ["mean", "statistics", "detail"], by default "mean"
        (any other value produces an empty table)
    print_results : bool, optional
        if True, the results will be printed to the console, by default True
    classes : list, optional
        a list of names of classes to include (by default all are included)

    Returns
    -------
    results : pd.DataFrame
        a table with the results

    """
    # Resolve the runs of every episode once; they are needed both for loading
    # the episode table and for the per-episode aggregation below.
    runs_by_episode = {}
    run_names = []
    for episode in episode_names:
        episode_runs = list(self._episodes().get_runs(episode))
        runs_by_episode[episode] = episode_runs
        run_names += episode_runs
    episodes = self.list_episodes(run_names, print_results=False)
    metric_columns = [x for x in episodes.columns if x[0] == "results"]
    results_df = pd.DataFrame()
    if metrics is not None:
        metric_columns = [
            x for x in metric_columns if x[1].split("_")[0] in metrics
        ]
    for episode in episode_names:
        results = []
        metric_set = set()
        for run in runs_by_episode[episode]:
            beh_dict = self.get_behavior_dictionary(run)
            res_dict = defaultdict(dict)
            for column in metric_columns:
                value = episodes.loc[run, column]
                if np.isnan(value):
                    continue
                split = column[1].split("_")
                if split[-1].isnumeric():
                    # "<metric>_<index>" columns hold per-behavior values
                    beh = beh_dict[int(split[-1])]
                    metric_name = "_".join(split[:-1])
                else:
                    beh = "average"
                    metric_name = column[1]
                res_dict[beh][metric_name] = value
                metric_set.add(metric_name)
            average = res_dict["average"]  # created empty if the run has no average yet
            for metric in metric_set:
                if metric in average:
                    continue
                arr = [
                    beh_results[metric]
                    for beh_results in res_dict.values()
                    if metric in beh_results
                ]
                # `metric_set` is shared by all runs of the episode, so a run may have
                # no value at all for a metric; leave it missing instead of storing the
                # NaN that `np.mean([])` would produce and that would poison the
                # episode-level average.
                if arr:
                    average[metric] = np.mean(arr)
            results.append(res_dict)
        # Collect behaviors from every run (first-seen order), so that a behavior that is
        # missing in the first run is still reported.
        behaviors = list(
            dict.fromkeys(beh for res_dict in results for beh in res_dict)
        )
        episode_results = {}
        for metric in metric_set:
            for beh in behaviors:
                if classes is not None and beh not in classes:
                    continue
                arr = [
                    res_dict[beh][metric]
                    for res_dict in results
                    if metric in res_dict.get(beh, {})
                ]
                if not arr:
                    continue
                if mode == "statistics":
                    episode_results[(beh, f"{episode} {metric} mean")] = np.mean(arr)
                    episode_results[(beh, f"{episode} {metric} std")] = np.std(arr)
                elif mode == "mean":
                    episode_results[(beh, f"{episode} {metric}")] = np.mean(arr)
                elif mode == "detail":
                    for i, val in enumerate(arr):
                        episode_results[(beh, f"{episode}::{i} {metric}")] = val
        for (row, column), value in episode_results.items():
            results_df.loc[row, column] = value
    if print_results:
        print("RESULTS:")
        print(results_df)
        print("\n")
    return results_df
Generate a pandas dataframe with a summary of episode results.
Parameters
episode_names : list a list of names of episodes to include metrics : list, optional a list of metric names to include mode : str, optional the mode of the results table, choose between ["mean", "statistics", "detail"], by default "mean" print_results : bool, optional if True, the results will be printed to the console, by default True classes : list, optional a list of names of classes to include (by default all are included)
Returns
results : pd.DataFrame a table with the results
def episode_exists(self, episode_name: str) -> bool:
    """Check if an episode already exists.

    Parameters
    ----------
    episode_name : str
        the name of the episode to look up

    Returns
    -------
    exists : bool
        the value returned by the episode store's `check_name_validity` for this name
        (documented as `True` if the episode exists)

    """
    # NOTE(review): the result of `check_name_validity` is passed through unchanged;
    # confirm that the helper returns True for names that are already in use rather
    # than for names that are still available.
    episode_store = self._episodes()
    return episode_store.check_name_validity(episode_name)
Check if an episode already exists.
Parameters
episode_name : str the episode name
Returns
exists : bool
    True if the episode exists
def search_exists(self, search_name: str) -> bool:
    """Check if a search already exists.

    Parameters
    ----------
    search_name : str
        the name of the search to look up

    Returns
    -------
    exists : bool
        the value returned by the search store's `check_name_validity` for this name
        (documented as `True` if the search exists)

    """
    # NOTE(review): the result of `check_name_validity` is passed through unchanged;
    # confirm that the helper returns True for names that are already in use rather
    # than for names that are still available.
    search_store = self._searches()
    return search_store.check_name_validity(search_name)
Check if a search already exists.
Parameters
search_name : str the search name
Returns
exists : bool
    True if the search exists
def prediction_exists(self, prediction_name: str) -> bool:
    """Check if a prediction already exists.

    Parameters
    ----------
    prediction_name : str
        the name of the prediction to look up

    Returns
    -------
    exists : bool
        the value returned by the prediction store's `check_name_validity` for this name
        (documented as `True` if the prediction exists)

    """
    # NOTE(review): the result of `check_name_validity` is passed through unchanged;
    # confirm that the helper returns True for names that are already in use rather
    # than for names that are still available.
    prediction_store = self._predictions()
    return prediction_store.check_name_validity(prediction_name)
Check if a prediction already exists.
Parameters
prediction_name : str the prediction name
Returns
exists : bool
    True if the prediction exists
@staticmethod
def project_name_available(projects_path: str, project_name: str):
    """Check if a project name is available.

    Parameters
    ----------
    projects_path : str
        the path to the projects directory (if `None`, the `DLC2Action` folder in the
        home directory is used)
    project_name : str
        the name of the project to check

    Returns
    -------
    available : bool
        `True` if nothing exists at the project's path yet

    """
    root = projects_path
    if root is None:
        root = os.path.join(str(Path.home()), "DLC2Action")
    candidate = os.path.join(root, project_name)
    if os.path.exists(candidate):
        return False
    return True
Check if a project name is available.
Parameters
projects_path : str the path to the projects directory project_name : str the name of the project to check
Returns
available : bool
    True if the project name is available
def rename_episode(self, episode_name: str, new_episode_name: str):
    """Rename an episode.

    The model folder and the log file of the episode are moved to the new name inside
    the project's `results` folder, then the record in the episode store is renamed.

    Parameters
    ----------
    episode_name : str
        the current episode name
    new_episode_name : str
        the new episode name

    Raises
    ------
    ValueError
        if a model folder or a log file already exists under the new name

    """
    results_dir = os.path.join(self.project_path, "results")
    old_model_dir = os.path.join(results_dir, "model", episode_name)
    new_model_dir = os.path.join(results_dir, "model", new_episode_name)
    old_log_file = os.path.join(results_dir, "logs", f"{episode_name}.txt")
    new_log_file = os.path.join(results_dir, "logs", f"{new_episode_name}.txt")
    # `shutil.move` onto an existing directory nests the source inside it, and an
    # existing log file may be silently overwritten, so refuse before touching anything.
    for destination in (new_model_dir, new_log_file):
        if os.path.exists(destination):
            raise ValueError(
                f"Cannot rename {episode_name} to {new_episode_name}: {destination} already exists!"
            )
    shutil.move(old_model_dir, new_model_dir)
    shutil.move(old_log_file, new_log_file)
    self._episodes().rename_episode(episode_name, new_episode_name)
Rename an episode.
Parameters
episode_name : str the current episode name new_episode_name : str the new episode name