dlc2action.project.project
Project interface
#
# Copyright 2020-2022 by A. Mathis Group and contributors. All rights reserved.
#
# This project and all its files are licensed under GNU AGPLv3 or later version. A copy is included in dlc2action/LICENSE.AGPL.
#
"""
Project interface
"""
import copy
import os
from typing import Dict, List, Tuple, Union, Set, Iterable, Any, Optional
import shutil

from numpy import ndarray
from ruamel.yaml import YAML
import pickle
import pandas as pd
from dlc2action.data.dataset import BehaviorDataset
from dlc2action.utils import apply_threshold
from collections.abc import Mapping
from collections import defaultdict

from dlc2action.task.task_dispatcher import TaskDispatcher
import warnings
from copy import deepcopy, copy
import time
import numpy as np
from matplotlib import pyplot as plt
from matplotlib import cm
from itertools import product
from collections.abc import Iterable
import optuna
import plotly
import torch
from pathlib import Path
from dlc2action import options, __version__
from ruamel.yaml.comments import CommentedMap, CommentedSet
from tqdm import tqdm
from dlc2action.project.meta import (
    Searches,
    SavedStores,
    Run,
    SavedRuns,
    DecisionThresholds,
)


class Project:
    """
    A class to create and maintain the project files + keep track of experiments
    """

    def __init__(
        self,
        name: str,
        data_type: str = None,
        annotation_type: str = "none",
        projects_path: str = None,
        data_path: Union[str, List] = None,
        annotation_path: Union[str, List] = None,
        copy: bool = False,
    ) -> None:
        """
        Parameters
        ----------
        name : str
            name of the project
        data_type : str, optional
            data type (run Project.data_types() to see available options; has to be provided if the project is
            being created)
        annotation_type : str, default 'none'
            annotation type (run Project.annotation_types() to see available options)
        projects_path : str, optional
            path to the projects folder (~/DLC2Action by default)
        data_path : str, optional
            path to the folder containing input files for the project (has to be provided if the project is
            being created)
        annotation_path : str, optional
            path to the folder containing annotation files for the project
        copy : bool, default False
            if True, the files from annotation_path and data_path will be copied to the projects folder;
            otherwise they will be moved
        """

        if projects_path is None:
            projects_path = os.path.join(str(Path.home()), "DLC2Action")
        if not os.path.exists(projects_path):
            os.mkdir(projects_path)
        self.project_path = os.path.join(projects_path, name)
        self.name = name
        self.data_type = data_type
        self.annotation_type = annotation_type
        self.data_path = data_path
        self.annotation_path = annotation_path
        if not os.path.exists(self.project_path):
            if data_type is None:
                raise ValueError(
                    "The data_type parameter is necessary when creating a new project!"
                )
            self._initialize_project(
                data_type, annotation_type, data_path, annotation_path, copy
            )
        else:
            self.annotation_type, self.data_type = self._read_types()
            if data_type != self.data_type and data_type is not None:
                raise ValueError(
                    f"The project has already been initialized with data_type={self.data_type}!"
                )
            if annotation_type != self.annotation_type and annotation_type != "none":
                raise ValueError(
                    f"The project has already been initialized with annotation_type={self.annotation_type}!"
                )
            self.annotation_path, data_path = self._read_paths()
            if self.data_path is None:
                self.data_path = data_path
            if annotation_path != self.annotation_path and annotation_path is not None:
                raise ValueError(
                    f"The project has already been initialized with annotation_path={self.annotation_path}!"
                )
        self._update_configs()
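
    # Usage sketch: opening or creating a project. The data type, annotation
    # type and paths below are hypothetical placeholders; run
    # Project.data_types() and Project.annotation_types() for the real options.
    #
    #     from dlc2action.project.project import Project
    #
    #     project = Project(
    #         "my_project",
    #         data_type="dlc_track",
    #         annotation_type="dlc",
    #         data_path="/path/to/tracking_files",
    #         annotation_path="/path/to/annotation_files",
    #     )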

    def _aggregate_predictions(
        self,
        prediction_name: str,
        episode_names: List,
        load_epochs: List = None,
        parameters_update: Dict = None,
        data_path: str = None,
        file_paths: Set = None,
        mode: str = "all",
        augment_n: int = 0,
        evaluate: bool = False,
        task: TaskDispatcher = None,
        embedding: bool = False,
    ) -> Tuple[TaskDispatcher, Dict, str, Dict, Optional[str]]:
        """
        Generate a prediction
        """

        if load_epochs is None:
            load_epochs = [None for _ in episode_names]
        prediction = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: 0)))
        cnt = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: 0)))
        behs = set(self.get_behavior_dictionary(episode_names[0]).values())
        if not all(
            [
                set(self.get_behavior_dictionary(x).values()) == behs
                for x in episode_names
            ]
        ):
            raise ValueError(f"The behavior sets are different in {episode_names}")
        behaviors = set()
        for i, episode_name in enumerate(episode_names):
            task, parameters, data_mode, new_pred, _ = self._make_prediction(
                prediction_name,
                episode_names=[episode_name],
                load_epochs=[load_epochs[i]],
                parameters_update=parameters_update,
                data_path=data_path,
                file_paths=file_paths,
                mode=mode,
                augment_n=augment_n,
                evaluate=evaluate,
                task=task,
                embedding=embedding,
            )
            new_pred = task.dataset(data_mode).generate_full_length_prediction(new_pred)
            beh_dict = task.behaviors_dict()
            for video_id, video_values in new_pred.items():
                for clip_id, clip_prediction in video_values.items():
                    for beh_i in range(clip_prediction.shape[0]):
                        prediction[video_id][clip_id][
                            beh_dict[beh_i]
                        ] += clip_prediction[beh_i, :].unsqueeze(0)
                        cnt[video_id][clip_id][beh_dict[beh_i]] += 1
                        behaviors.add(beh_dict[beh_i])
        output = defaultdict(lambda: {})
        behavior_indices = sorted(
            [x for x in task.behaviors_dict().keys() if x != -100]
        )
        behaviors = [task.behaviors_dict()[key] for key in behavior_indices]
        for video_id, video_values in prediction.items():
            for clip_id, clip_values in video_values.items():
                pred = torch.cat(
                    [
                        clip_values[beh] / cnt[video_id][clip_id][beh]
                        for beh in behaviors
                    ],
                    0,
                )
                output[video_id][clip_id] = pred
        return task, parameters, data_mode, dict(output), None

    def _make_prediction(
        self,
        prediction_name: str,
        episode_names: List,
        load_epochs: List = None,
        parameters_update: Dict = None,
        data_path: str = None,
        file_paths: Set = None,
        mode: str = "all",
        augment_n: int = 0,
        evaluate: bool = False,
        task: TaskDispatcher = None,
        embedding: bool = False,
    ) -> Tuple[TaskDispatcher, Dict, str, torch.Tensor, str]:
        """
        Generate a prediction
        """

        names = []
        epochs = []
        if load_epochs is None:
            load_epochs = [None for _ in episode_names]
        if len(load_epochs) != len(episode_names):
            raise ValueError(
                "The length of load_epochs and the length of episode_names should be the same!"
            )
        for i, episode_name in enumerate(episode_names):
            names += self._episodes().get_runs(episode_name)
            epochs.append(load_epochs[i])
        if len(names) == 0:
            warnings.warn(f"None of the episodes {episode_names} exist!")
            names = [None]
        episodes = self._episodes()
        lengths = [
            episodes.load_parameters(name)["general"]["len_segment"] for name in names
        ]
        overlaps = [
            episodes.load_parameters(name)["general"]["overlap"] for name in names
        ]
        if not all([x == lengths[0] for x in lengths]):
            raise ValueError(f"Episodes {episode_names} have different segment lengths")
        if not all([x == overlaps[0] for x in overlaps]):
            raise ValueError(f"Episodes {episode_names} have different overlaps")
        load_epochs = epochs
        prediction = None
        decision_thresholds = None
        time_total = 0
        behavior_dicts = [
            self.get_behavior_dictionary(episode_name) for episode_name in names
        ]
        if not all(
            [
                set(d.values()) == set(behavior_dicts[0].values())
                for d in behavior_dicts[1:]
            ]
        ):
            raise ValueError(
                f"Episodes {episode_names} have different sets of behaviors!"
            )
        behavior_indices = [x for x in behavior_dicts[0].keys() if x != -100]
        behaviors = [behavior_dicts[0][i] for i in behavior_indices]
        cnt = defaultdict(lambda: 0)
        behavior_probs = defaultdict(lambda: 0)
        for episode_name, load_epoch, behavior_dict in zip(
            names, load_epochs, behavior_dicts
        ):
            print(f"episode {episode_name}")
            task, parameters, data_mode = self._make_task_prediction(
                prediction_name=prediction_name,
                load_episode=episode_name,
                parameters_update=parameters_update,
                load_epoch=load_epoch,
                data_path=data_path,
                mode=mode,
                file_paths=file_paths,
                task=task,
                decision_thresholds=decision_thresholds,
            )
            behavior_indices_cur = [x for x in behavior_dict.keys() if x != -100]
            behaviors_cur = [behavior_dict[i] for i in behavior_indices_cur]
            time_start = time.time()
            new_pred = task.predict(
                data_mode,
                raw_output=True,
                apply_primary_function=True,
                augment_n=augment_n,
                embedding=embedding,
            )
            for j, beh in enumerate(behaviors_cur):
                cnt[beh] += 1
                behavior_probs[beh] += new_pred[:, j, :].unsqueeze(1)
            time_end = time.time()
            time_total += time_end - time_start
            if evaluate:
                _, metrics = task.evaluate_prediction(new_pred, data=data_mode)
                if mode == "val":
                    self._update_episode_metrics(episode_name, metrics)
            print("\n")
        prediction = torch.cat([behavior_probs[beh] / cnt[beh] for beh in behaviors], 1)
        hours = int(time_total // 3600)
        time_total -= hours * 3600
        minutes = int(time_total // 60)
        time_total -= minutes * 60
        seconds = int(time_total)
        inference_time = f"{hours}:{minutes:02}:{seconds:02}"
        return task, parameters, data_mode, prediction, inference_time

    def _make_task_prediction(
        self,
        prediction_name: str,
        load_episode: str = None,
        parameters_update: Dict = None,
        load_epoch: int = None,
        data_path: str = None,
        mode: str = "val",
        file_paths: Set = None,
        decision_thresholds: List = None,
        task: TaskDispatcher = None,
    ) -> Tuple[TaskDispatcher, Dict, str]:
        """
        Make a `TaskDispatcher` object that will be used to generate a prediction
        """

        if parameters_update is None:
            parameters_update = {}
        parameters_update_second = {}
        if mode == "all" or data_path is not None or file_paths is not None:
            parameters_update_second["training"] = {
                "val_frac": 0,
                "test_frac": 0,
                "partition_method": "random",
                "save_split": False,
                "split_path": None,
            }
            mode = "train"
        if decision_thresholds is not None:
            if (
                len(decision_thresholds)
                == self._episode(load_episode).get_num_classes()
            ):
                parameters_update_second["general"] = {
                    "threshold_value": decision_thresholds
                }
            else:
                raise ValueError(
                    f"The length of the decision thresholds {decision_thresholds} "
                    f"must be equal to the length of the behaviors dictionary "
                    f"{self._episode(load_episode).get_behaviors_dict()}"
                )
        data_param_update = {}
        if data_path is not None:
            data_param_update = {"data_path": data_path}
        if file_paths is not None:
            data_param_update = {"data_path": None, "file_paths": file_paths}
        parameters_update = self._update(parameters_update, {"data": data_param_update})
        if data_path is not None or file_paths is not None:
            general_update = {
                "annotation_type": "none",
                "only_load_annotated": False,
            }
        else:
            general_update = {}
        parameters_update = self._update(parameters_update, {"general": general_update})
        task, parameters = self._make_task(
            episode_name=prediction_name,
            load_episode=load_episode,
            parameters_update=parameters_update,
            parameters_update_second=parameters_update_second,
            load_epoch=load_epoch,
            purpose="prediction",
            task=task,
            behaviors=self.get_behavior_dictionary(load_episode),
        )
        if mode is None:
            if task.exists("test"):
                mode = "test"
            elif task.exists("val"):
                mode = "val"
            else:
                mode = "train"
        return task, parameters, mode

    def _make_task_training(
        self,
        episode_name: str,
        load_episode: str = None,
        parameters_update: Dict = None,
        load_epoch: int = None,
        load_search: str = None,
        load_parameters: list = None,
        round_to_binary: list = None,
        load_strict: bool = True,
        continuing: bool = False,
        task: TaskDispatcher = None,
        mask_name: str = None,
        throwaway: bool = False,
    ) -> Tuple[TaskDispatcher, Dict]:
        """
        Make a `TaskDispatcher` object that will be used for training
        """

        if parameters_update is None:
            parameters_update = {}
        if continuing:
            purpose = "continuing"
        else:
            purpose = "training"
        if mask_name is not None:
            mask_name = os.path.join(self._mask_path(), f"{mask_name}.pickle")
        parameters_update_second = {"data": {"real_lens": mask_name}}
        if throwaway:
            parameters_update = self._update(
                parameters_update, {"training": {"normalize": False, "device": "cpu"}}
            )
        return self._make_task(
            episode_name,
            load_episode,
            parameters_update,
            parameters_update_second,
            load_epoch,
            load_search,
            load_parameters,
            round_to_binary,
            purpose,
            task,
            load_strict=load_strict,
        )

    def _make_parameters(
        self,
        episode_name: str,
        load_episode: str = None,
        parameters_update: Dict = None,
        parameters_update_second: Dict = None,
        load_epoch: int = None,
        load_search: str = None,
        load_parameters: list = None,
        round_to_binary: list = None,
        purpose: str = "train",
        load_strict: bool = True,
    ):
        """
        Construct a parameters dictionary
        """

        if parameters_update is None:
            parameters_update = {}
        pars_update = deepcopy(parameters_update)
        if parameters_update_second is None:
            parameters_update_second = {}
        if purpose == "prediction" and "model" in pars_update.keys():
            raise ValueError("Cannot change model parameters after training!")
        if purpose in ["continuing", "prediction"] and load_episode is not None:
            read_parameters = self._read_parameters()
            parameters = self._episodes().load_parameters(load_episode)
            parameters["metrics"] = self._update(
                read_parameters["metrics"], parameters["metrics"]
            )
            parameters["ssl"] = self._update(
                read_parameters["ssl"], parameters.get("ssl", {})
            )
        else:
            parameters = self._read_parameters()
        if "model" in pars_update:
            model_params = pars_update.pop("model")
        else:
            model_params = None
        if "features" in pars_update:
            feat_params = pars_update.pop("features")
        else:
            feat_params = None
        if "augmentations" in pars_update:
            aug_params = pars_update.pop("augmentations")
        else:
            aug_params = None
        parameters = self._update(parameters, pars_update)
        if pars_update.get("general", {}).get("model_name") is not None:
            model_name = parameters["general"]["model_name"]
            parameters["model"] = self._open_yaml(
                os.path.join(self.project_path, "config", "model", f"{model_name}.yaml")
            )
        if pars_update.get("general", {}).get("feature_extraction") is not None:
            feat_name = parameters["general"]["feature_extraction"]
            parameters["features"] = self._open_yaml(
                os.path.join(
                    self.project_path, "config", "features", f"{feat_name}.yaml"
                )
            )
            aug_name = options.extractor_to_transformer[
                parameters["general"]["feature_extraction"]
            ]
            parameters["augmentations"] = self._open_yaml(
                os.path.join(
                    self.project_path, "config", "augmentations", f"{aug_name}.yaml"
                )
            )
        if model_params is not None:
            parameters["model"] = self._update(parameters["model"], model_params)
        if feat_params is not None:
            parameters["features"] = self._update(parameters["features"], feat_params)
        if aug_params is not None:
            parameters["augmentations"] = self._update(
                parameters["augmentations"], aug_params
            )
        if load_search is not None:
            parameters = self._update_with_search(
                parameters, load_search, load_parameters, round_to_binary
            )
        parameters = self._fill(
            parameters,
            episode_name,
            load_episode,
            load_epoch=load_epoch,
            load_strict=load_strict,
            only_load_model=(purpose != "continuing"),
            continuing=(purpose in ["prediction", "continuing"]),
            enforce_split_parameters=(purpose == "prediction"),
        )
        parameters = self._update(parameters, parameters_update_second)
        return parameters

    def _make_task(
        self,
        episode_name: str,
        load_episode: str = None,
        parameters_update: Dict = None,
        parameters_update_second: Dict = None,
        load_epoch: int = None,
        load_search: str = None,
        load_parameters: list = None,
        round_to_binary: list = None,
        purpose: str = "train",
        task: TaskDispatcher = None,
        load_strict: bool = True,
        behaviors: Dict = None,
    ) -> Tuple[TaskDispatcher, Union[CommentedMap, dict]]:
        """
        Make a `TaskDispatcher` object

        The task parameters are read from the config files and then updated with the
        parameters_update dictionary. The model can be either initialized from scratch or loaded from one of the
        previous experiments. All parameters and results are saved in the meta files and can be accessed with the
        list_episodes() function. The train/test/validation split is saved and loaded from a file whenever the
        same split parameters are used. The pre-computed datasets are also saved and loaded whenever the same
        data parameters are used.

        Parameters
        ----------
        episode_name : str
            the name of the episode
        load_episode : str, optional
            the (previously run) episode name to load the model from
        parameters_update : dict, optional
            the dictionary used to update the parameters from the config
        parameters_update_second : dict, optional
            the dictionary used to update the parameters after the automatic fill-out
        load_epoch : int, optional
            the epoch to load (if load_episode is not None); if not provided, the last epoch is used
        load_search : str, optional
            the hyperparameter search result to load
        load_parameters : list, optional
            a list of string names of the parameters to load from load_search (if not provided, all parameters
            are loaded)
        round_to_binary : list, optional
            a list of string names of the loaded parameters that should be rounded to the nearest power of two
        purpose : {"train", "continuing", "prediction"}
            the purpose of the task object (`"train"` for training from scratch, `"continuing"` for continuing
            the training of an interrupted episode, `"prediction"` for generating a prediction)
        task : TaskDispatcher, optional
            a pre-existing task; if provided, the method will update the task instead of creating a new one
            (this might save time, mainly on dataset loading)

        Returns
        -------
        task : TaskDispatcher
            the `TaskDispatcher` instance
        parameters : dict
            the parameters dictionary that describes the task
        """

        parameters = self._make_parameters(
            episode_name,
            load_episode,
            parameters_update,
            parameters_update_second,
            load_epoch,
            load_search,
            load_parameters,
            round_to_binary,
            purpose,
            load_strict=load_strict,
        )
        if parameters["data"].get("annotation_type", "none") == "none":
            parameters = self._update(
                parameters, {"data": {"behavior_dictionary": behaviors}}
            )
        if task is None:
            task = TaskDispatcher(parameters)
        else:
            task.update_task(parameters)
        self._save_stores(parameters)
        return task, parameters

    def run_episode(
        self,
        episode_name: str,
        load_episode: str = None,
        parameters_update: Dict = None,
        task: TaskDispatcher = None,
        load_epoch: int = None,
        load_search: str = None,
        load_parameters: list = None,
        round_to_binary: list = None,
        load_strict: bool = True,
        n_seeds: int = 1,
        force: bool = False,
        suppress_name_check: bool = False,
        remove_saved_features: bool = False,
        mask_name: str = None,
        autostop_metric: str = None,
        autostop_interval: int = 50,
        autostop_threshold: float = 0.001,
        loading_bar: bool = False,
        trial: Tuple = None,
    ) -> TaskDispatcher:
        """
        Run an episode

        The task parameters are read from the config files and then updated with the
        parameters_update dictionary. The model can be either initialized from scratch or loaded from one of the
        previous experiments. All parameters and results are saved in the meta files and can be accessed with the
        list_episodes() function. The train/test/validation split is saved and loaded from a file whenever the
        same split parameters are used. The pre-computed datasets are also saved and loaded whenever the same
        data parameters are used.

        You can use the autostop parameters to finish training when the metric is no longer improving. Training
        will be stopped if the average value of `autostop_metric` over the last `autostop_interval` epochs is
        smaller than the average over the previous `autostop_interval` epochs + `autostop_threshold`. For example,
        if the current epoch is 120 and `autostop_interval` is 50, the averages over epochs 70-120 and 20-70 will
        be compared.

        Parameters
        ----------
        episode_name : str
            the episode name
        load_episode : str, optional
            the (previously run) episode name to load the model from; if the episode has multiple runs,
            the new episode will have the same number of runs, each starting with one of the pre-trained models
        parameters_update : dict, optional
            the dictionary used to update the parameters from the config files
        task : TaskDispatcher, optional
            a pre-existing `TaskDispatcher` object (if provided, the method will update it instead of creating
            a new instance)
        load_epoch : int, optional
            the epoch to load (if load_episode is not None); if not provided, the last epoch is used
        load_search : str, optional
            the hyperparameter search result to load
        load_parameters : list, optional
            a list of string names of the parameters to load from load_search (if not provided, all parameters
            are loaded)
        round_to_binary : list, optional
            a list of string names of the loaded parameters that should be rounded to the nearest power of two
        load_strict : bool, default True
            if `False`, matching weights will be loaded from `load_episode` and differences in parameter name lists and
            weight shapes will be ignored; otherwise mismatches will prompt a `RuntimeError`
        n_seeds : int, default 1
            the number of runs to perform with different random seeds; if `n_seeds > 1`, the episodes will be named
            `episode_name::seed_index`, e.g. `test_episode::0` and `test_episode::1`
        force : bool, default False
            if `True` and an episode with name `episode_name` already exists, it will be overwritten (use with caution!)
        suppress_name_check : bool, default False
            if `True`, episode names with a double colon are allowed (please don't use this option unless you understand
            why they are usually forbidden)
        remove_saved_features : bool, default False
            if `True`, the dataset will be deleted after training
        mask_name : str, optional
            the name of the real_lens to apply
        autostop_interval : int, default 50
            the number of epochs to average the autostop metric over
        autostop_threshold : float, default 0.001
            the autostop difference threshold
        autostop_metric : str, optional
            the autostop metric (can be any one of the tracked metrics or `'loss'`)
        """

        if type(n_seeds) is not int or n_seeds < 1:
            raise ValueError(
                f"The n_seeds parameter has to be an integer larger than 0; got {n_seeds}"
            )
        if n_seeds > 1 and mask_name is not None:
            raise ValueError("Cannot apply a real_lens with n_seeds > 1")
        self._check_episode_validity(
            episode_name, allow_doublecolon=suppress_name_check, force=force
        )
        load_runs = self._episodes().get_runs(load_episode)
        if len(load_runs) > 1:
            task = self.run_episodes(
                episode_names=[
                    f'{episode_name}::{run.split("::")[-1]}' for run in load_runs
                ],
                load_episodes=load_runs,
                parameters_updates=[parameters_update for _ in load_runs],
                load_epochs=[load_epoch for _ in load_runs],
                load_searches=[load_search for _ in load_runs],
                load_parameters=[load_parameters for _ in load_runs],
                round_to_binary=[round_to_binary for _ in load_runs],
                load_strict=[load_strict for _ in load_runs],
                suppress_name_check=True,
                force=force,
                remove_saved_features=False,
            )
            if remove_saved_features:
                self._remove_stores(
                    {
                        "general": task.general_parameters,
                        "data": task.data_parameters,
                        "features": task.feature_parameters,
                    }
                )
            if n_seeds > 1:
                warnings.warn(
                    f"The n_seeds parameter is disregarded since load_episode={load_episode} has multiple runs"
                )
        elif n_seeds > 1:
            self.run_episodes(
                episode_names=[f"{episode_name}::{i}" for i in range(n_seeds)],
                load_episodes=[load_episode for _ in range(n_seeds)],
                parameters_updates=[parameters_update for _ in range(n_seeds)],
                load_epochs=[load_epoch for _ in range(n_seeds)],
                load_searches=[load_search for _ in range(n_seeds)],
                load_parameters=[load_parameters for _ in range(n_seeds)],
                round_to_binary=[round_to_binary for _ in range(n_seeds)],
                load_strict=[load_strict for _ in range(n_seeds)],
                suppress_name_check=True,
                force=force,
                remove_saved_features=remove_saved_features,
            )
        else:
            print(f"TRAINING {episode_name}")
            try:
                task, parameters = self._make_task_training(
                    episode_name,
                    load_episode,
                    parameters_update,
                    load_epoch,
                    load_search,
                    load_parameters,
                    round_to_binary,
                    continuing=False,
                    task=task,
                    mask_name=mask_name,
                    load_strict=load_strict,
                )
                self._save_episode(
                    episode_name,
                    parameters,
                    task.behaviors_dict(),
                    norm_stats=task.get_normalization_stats(),
                )
                time_start = time.time()
                if trial is not None:
                    trial, metric = trial
                else:
                    trial, metric = None, None
                logs = task.train(
                    autostop_metric=autostop_metric,
                    autostop_interval=autostop_interval,
                    autostop_threshold=autostop_threshold,
                    loading_bar=loading_bar,
                    trial=trial,
                    optimized_metric=metric,
                )
                time_end = time.time()
                time_total = time_end - time_start
                hours = int(time_total // 3600)
                time_total -= hours * 3600
                minutes = int(time_total // 60)
                time_total -= minutes * 60
                seconds = int(time_total)
                training_time = f"{hours}:{minutes:02}:{seconds:02}"
                self._update_episode_results(episode_name, logs, training_time)
                if remove_saved_features:
                    self._remove_stores(parameters)
                print("\n")
                return task

            except Exception as e:
                if isinstance(e, optuna.exceptions.TrialPruned):
                    raise e
                else:
                    raise RuntimeError(f"Episode {episode_name} could not run") from e
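
    # Usage sketch: training an episode with a few config overrides. The model
    # name and parameter values are hypothetical; see project.help('models')
    # for the real options.
    #
    #     project.run_episode(
    #         "first_episode",
    #         parameters_update={
    #             "general": {"model_name": "c2f_tcn"},
    #             "training": {"num_epochs": 100},
    #         },
    #         n_seeds=3,  # creates first_episode::0, first_episode::1, first_episode::2
    #     )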

    def run_episodes(
        self,
        episode_names: List,
        load_episodes: List = None,
        parameters_updates: List = None,
        load_epochs: List = None,
        load_searches: List = None,
        load_parameters: List = None,
        round_to_binary: List = None,
        load_strict: List = None,
        force: bool = False,
        suppress_name_check: bool = False,
        remove_saved_features: bool = False,
    ) -> TaskDispatcher:
        """
        Run multiple episodes in sequence (and re-use previously loaded information)

        For each episode, the task parameters are read from the config files and then updated with the
        corresponding parameters_updates dictionary. The model can be either initialized from scratch or loaded from one of the
        previous experiments. All parameters and results are saved in the meta files and can be accessed with the
        list_episodes() function. The train/test/validation split is saved and loaded from a file whenever the
        same split parameters are used. The pre-computed datasets are also saved and loaded whenever the same
        data parameters are used.

        Parameters
        ----------
        episode_names : list
            a list of strings of episode names
        load_episodes : list, optional
            a list of strings of (previously run) episode names to load the model from; if an episode has multiple runs,
            the new episode will have the same number of runs, each starting with one of the pre-trained models
        parameters_updates : list, optional
            a list of dictionaries used to update the parameters from the config
        load_epochs : list, optional
            a list of integers used to specify the epoch to load (if load_episodes is not None)
        load_searches : list, optional
            a list of strings of hyperparameter search results to load
        load_parameters : list, optional
            a list of lists of string names of the parameters to load from the searches
        round_to_binary : list, optional
            a list of string names of the loaded parameters that should be rounded to the nearest power of two
        load_strict : list, optional
            a list of boolean values specifying the weight loading policy: if `False`, matching weights will be loaded from
            the corresponding episode and differences in parameter name lists and
            weight shapes will be ignored; otherwise mismatches will prompt a `RuntimeError` (by default `True` for
            every episode)
        force : bool, default False
            if `True` and an episode name is already taken, it will be overwritten (use with caution!)
        suppress_name_check : bool, default False
            if `True`, episode names with a double colon are allowed (please don't use this option unless you understand
            why they are usually forbidden)
        remove_saved_features : bool, default False
            if `True`, the dataset will be deleted after training
        """

        task = None
        if load_searches is None:
            load_searches = [None for _ in episode_names]
        if load_episodes is None:
            load_episodes = [None for _ in episode_names]
        if parameters_updates is None:
            parameters_updates = [None for _ in episode_names]
        if load_parameters is None:
            load_parameters = [None for _ in episode_names]
        if load_epochs is None:
            load_epochs = [None for _ in episode_names]
        if load_strict is None:
            load_strict = [True for _ in episode_names]
        for (
            parameters_update,
            episode_name,
            load_episode,
            load_epoch,
            load_search,
            load_parameters_list,
            load_strict_value,
        ) in zip(
            parameters_updates,
            episode_names,
            load_episodes,
            load_epochs,
            load_searches,
            load_parameters,
            load_strict,
        ):
            task = self.run_episode(
                episode_name,
                load_episode,
                parameters_update,
                task,
                load_epoch,
                load_search,
                load_parameters_list,
                round_to_binary,
                load_strict_value,
                suppress_name_check=suppress_name_check,
                force=force,
                remove_saved_features=remove_saved_features,
            )
        return task
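
    # Usage sketch: running two episodes back to back so that the returned
    # task (and its loaded dataset) is re-used between them; the episode
    # names and overlap values are hypothetical.
    #
    #     project.run_episodes(
    #         episode_names=["episode_A", "episode_B"],
    #         parameters_updates=[
    #             {"general": {"overlap": 50}},
    #             {"general": {"overlap": 70}},
    #         ],
    #     )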

    def continue_episode(
        self,
        episode_name: str,
        num_epochs: int = None,
        task: TaskDispatcher = None,
        n_seeds: int = 1,
        remove_saved_features: bool = False,
        device: str = "cuda",
        num_cpus: int = None,
    ) -> TaskDispatcher:
        """
        Load an older episode and continue running from the latest checkpoint

        All parameters as well as the model and optimizer state dictionaries are loaded from the episode.

        Parameters
        ----------
        episode_name : str
            the name of the episode to continue
        num_epochs : int, optional
            the new number of epochs
        task : TaskDispatcher, optional
            a pre-existing task; if provided, the method will update the task instead of creating a new one
            (this might save time, mainly on dataset loading)
        n_seeds : int, default 1
            the number of runs to perform; if `n_seeds > 1`, the episodes will be named `episode_name::run_index`, e.g.
            `test_episode::0` and `test_episode::1`
        remove_saved_features : bool, default False
            if `True`, pre-computed features will be deleted after the run
        device : str, default "cuda"
            the torch device to use
        num_cpus : int, optional
            the number of CPUs to use in data processing
        """

        runs = self._episodes().get_runs(episode_name)
        for run in runs:
            print(f"TRAINING {run}")
            if num_epochs is None and not self._episode(run).unfinished():
                continue
            parameters_update = {
                "training": {
                    "num_epochs": num_epochs,
                    "device": device,
                },
                "general": {"num_cpus": num_cpus},
            }
            task, parameters = self._make_task_training(
                run,
                load_episode=run,
                parameters_update=parameters_update,
                continuing=True,
                task=task,
            )
            time_start = time.time()
            logs = task.train()
            time_end = time.time()
            old_time = self._training_time(run)
            if not np.isnan(old_time):
                time_end += old_time
                time_total = time_end - time_start
                hours = int(time_total // 3600)
                time_total -= hours * 3600
                minutes = int(time_total // 60)
                time_total -= minutes * 60
                seconds = int(time_total)
                training_time = f"{hours}:{minutes:02}:{seconds:02}"
            else:
                training_time = np.nan
            self._save_episode(
                run,
                parameters,
                task.behaviors_dict(),
                suppress_validation=True,
                training_time=training_time,
                norm_stats=task.get_normalization_stats(),
            )
            self._update_episode_results(run, logs)
            print("\n")
        if len(runs) < n_seeds:
            for i in range(len(runs), n_seeds):
                self.run_episode(
                    f"{episode_name}::{i}",
                    parameters_update=self._episodes().load_parameters(runs[0]),
                    task=task,
                    suppress_name_check=True,
                )
        if remove_saved_features:
            self._remove_stores(parameters)
        return task
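
    # Usage sketch: resuming an interrupted episode and extending it to a
    # hypothetical total of 200 epochs on a specific device.
    #
    #     project.continue_episode("first_episode", num_epochs=200, device="cuda:0")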

    def run_default_hyperparameter_search(
        self,
        search_name: str,
        model_name: str = None,
        metric: str = "f1",
        best_n: int = 3,
        direction: str = "maximize",
        load_episode: str = None,
        load_epoch: int = None,
        load_strict: bool = True,
        prune: bool = True,
        force: bool = False,
        remove_saved_features: bool = False,
        overlap: float = 0,
        num_epochs: int = 50,
        test_frac: float = 0,
        n_trials: int = 150,
        device: str = None,
    ):
        """
        Run an optuna hyperparameter search with default parameters for a model

        For the vast majority of cases, optimizing the default parameters should be enough.
        Check out `dlc2action.options.model_hyperparameters` for the lists of parameters.
        There are also options to set the overlap, test fraction and number of epochs parameters for the search
        without modifying the project config files. However, if you want something more complex, look into
        `Project.run_hyperparameter_search`.

        The task parameters are read from the config files and updated with the search settings.
        The model can be either initialized from scratch or loaded from a previously run episode.
        For each trial, the objective metric is averaged over a few best epochs.

        Parameters
        ----------
        search_name : str
            the name of the search to store it in the meta files and load in run_episode
        model_name : str, optional
            the name of the model (by default loaded from the project settings, see `project.help('models')` for options)
        metric : str, default f1
            the metric to maximize/minimize (see direction); if the metric has an `"average"` parameter and it is set to
            `"none"` in the config files, it will be reset to `"macro"` for the search; see `project.help('metrics')` for options
        best_n : int, default 3
            the number of epochs to average the metric over; if 0, the last value is taken
        direction : {'maximize', 'minimize'}
            optimization direction
        load_episode : str, optional
            the name of the episode to load the model from
        load_epoch : int, optional
            the epoch to load the model from (if not provided, the last checkpoint is used)
        load_strict : bool, default True
            if `False`, matching weights will be loaded from `load_episode` and mismatches will be ignored
        prune : bool, default True
            if `True`, experiments where the optimized metric is improving too slowly will be terminated
            (with optuna HyperBand pruner)
        force : bool, default False
            if `True`, existing searches with the same name will be overwritten
        remove_saved_features : bool, default False
            if `True`, pre-computed features will be deleted after each run (if the data parameters change)
        overlap : float, default 0
            the segment overlap to use in the search (without modifying the project config)
        num_epochs : int, default 50
            the number of epochs to train each trial for
        test_frac : float, default 0
            the test fraction to use in the search
        n_trials : int, default 150
            the number of optimization trials to run
        device : str, optional
            cuda:{i} or cpu, if not given it is read from the default parameters

        Returns
        -------
        dict
            a dictionary of best parameters
        """

        if model_name is None:
            model_name = self._read_parameters()["general"]["model_name"]
        if model_name not in options.model_hyperparameters:
            raise ValueError(
                f"There is no default search space for {model_name}! Please choose "
                f"from {options.model_hyperparameters.keys()} or try "
                f"project.run_hyperparameter_search()"
            )
        pars = {
            "general": {
                "overlap": overlap,
                "model_name": model_name,
                "metric_functions": {metric},
            },
            "training": {"num_epochs": num_epochs},
        }
        if test_frac is not None:
            pars["training"]["test_frac"] = test_frac
        if not metric.split("_")[-1].isnumeric():
            project_pars = self._read_parameters()
            if project_pars["metrics"][metric].get("average") == "none":
                pars["metrics"] = {metric: {"average": "macro"}}
        if device is not None:
            pars["training"]["device"] = device
        return self.run_hyperparameter_search(
            search_name=search_name,
            search_space=options.model_hyperparameters[model_name],
            metric=metric,
            n_trials=n_trials,
            best_n=best_n,
            parameters_update=pars,
            direction=direction,
            load_episode=load_episode,
            load_epoch=load_epoch,
            load_strict=load_strict,
            prune=prune,
            force=force,
            remove_saved_features=remove_saved_features,
        )
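
    # Usage sketch: a default search followed by an episode that loads its
    # best parameters; the search and model names are hypothetical.
    #
    #     best_params = project.run_default_hyperparameter_search(
    #         "tcn_search",
    #         model_name="c2f_tcn",
    #         metric="f1",
    #         n_trials=50,
    #     )
    #     project.run_episode("best_episode", load_search="tcn_search")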

    def run_hyperparameter_search(
        self,
        search_name: str,
        search_space: Dict,
        metric: str = "f1",
        n_trials: int = 20,
        best_n: int = 1,
        parameters_update: Dict = None,
        direction: str = "maximize",
        load_episode: str = None,
        load_epoch: int = None,
        load_strict: bool = True,
        prune: bool = False,
        force: bool = False,
        remove_saved_features: bool = False,
    ) -> Dict:
        """
        Run an optuna hyperparameter search

        For a simpler function that fits most use cases, check out `Project.run_default_hyperparameter_search()`.

        To use a default search space with this method, import `dlc2action.options.model_hyperparameters`. It is
        a dictionary where keys are model names and values are default search spaces.

        The task parameters are read from the config files and updated with the parameters_update dictionary.
        The model can be either initialized from scratch or loaded from a previously run episode.
        For each trial, the objective metric is averaged over a few best epochs.

        Parameters
        ----------
        search_name : str
            the name of the search to store it in the meta files and load in run_episode
        search_space : dict
            a dictionary representing the search space; of this general structure:
            {'group/param_name': ('float/int/float_log/int_log', start, end),
            'group/param_name': ('categorical', [choices])}, e.g.
            {'data/overlap': ('int', 5, 100), 'training/lr': ('float_log', 1e-4, 1e-2),
            'data/feature_extraction': ('categorical', ['kinematic', 'bones'])}
        metric : str, default f1
            the metric to maximize/minimize (see direction)
        n_trials : int, default 20
            the number of optimization trials to run
        best_n : int, default 1
            the number of epochs to average the metric over; if 0, the last value is taken
        parameters_update : dict, optional
            the parameters update dictionary
        direction : {'maximize', 'minimize'}
            optimization direction
        load_episode : str, optional
            the name of the episode to load the model from
        load_epoch : int, optional
            the epoch to load the model from (if not provided, the last checkpoint is used)
        prune : bool, default False
            if `True`, experiments where the optimized metric is improving too slowly will be terminated
            (with optuna HyperBand pruner)
        force : bool, default False
            if `True`, existing searches with the same name will be overwritten
        remove_saved_features : bool, default False
            if `True`, pre-computed features will be deleted after each run (if the data parameters change)

        Returns
        -------
        dict
            a dictionary of best parameters
        """

        self._check_search_validity(search_name, force=force)
        print(f"SEARCH {search_name}")
        self.remove_episode(f"_{search_name}")
        if parameters_update is None:
            parameters_update = {}
        parameters_update = self._update(
            parameters_update, {"general": {"metric_functions": {metric}}}
        )
        parameters = self._make_parameters(
            f"_{search_name}",
            load_episode,
            parameters_update,
            parameters_update_second={"training": {"model_save_path": None}},
            load_epoch=load_epoch,
            load_strict=load_strict,
        )
        task = None

        if prune:
            pruner = optuna.pruners.HyperbandPruner()
        else:
            pruner = optuna.pruners.NopPruner()
        study = optuna.create_study(direction=direction, pruner=pruner)
        runner = _Runner(
            search_space=search_space,
            load_episode=load_episode,
            load_epoch=load_epoch,
            metric=metric,
            average=best_n,
            task=task,
            remove_saved_features=remove_saved_features,
            project=self,
            search_name=search_name,
        )
        study.optimize(lambda trial: runner.run(trial, parameters), n_trials=n_trials)
        search_path = self._search_path(search_name)
        os.mkdir(search_path)
        fig = optuna.visualization.plot_contour(study)
        plotly.offline.plot(
            fig, filename=os.path.join(search_path, f"{search_name}_contour.html")
        )
        fig = optuna.visualization.plot_param_importances(study)
        plotly.offline.plot(
            fig, filename=os.path.join(search_path, f"{search_name}_importances.html")
        )
        best_params = study.best_params
        best_value = study.best_value
        self._save_search(
            search_name,
            parameters,
            n_trials,
            best_params,
            best_value,
            metric,
            search_space,
        )
        self.remove_episode(f"_{search_name}")
        runner.clean()
        print(f"best parameters: {best_params}")
        print("\n")
        return best_params
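
    # Usage sketch: a custom search space following the 'group/param_name'
    # convention described above; the search name is hypothetical and the
    # parameter names must match keys in your config files.
    #
    #     search_space = {
    #         "data/overlap": ("int", 5, 100),
    #         "training/lr": ("float_log", 1e-4, 1e-2),
    #         "data/feature_extraction": ("categorical", ["kinematic", "bones"]),
    #     }
    #     best_params = project.run_hyperparameter_search(
    #         "custom_search",
    #         search_space=search_space,
    #         metric="f1",
    #         n_trials=30,
    #         prune=True,
    #     )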

    def run_prediction(
        self,
        prediction_name: str,
        episode_names: List,
        load_epochs: List = None,
        parameters_update: Dict = None,
        augment_n: int = 10,
        data_path: str = None,
        mode: str = "all",
        file_paths: Set = None,
        remove_saved_features: bool = False,
        submission: bool = False,
        frame_number_map_file: str = None,
        force: bool = False,
        embedding: bool = False,
    ) -> None:
        """
        Load models from previously run episodes to generate a prediction

        The probabilities predicted by the models are averaged.
        Unless `submission` is `True`, the prediction results are saved as pickled dictionaries in the
        project_name/results/predictions/{prediction_name} folder, one file per video
        ({video_id}_{prediction_name}_prediction.pickle). Each file is a dictionary where the keys are the
        clip ids (like individual names) and the values are the prediction arrays, along with 'min_frames',
        'max_frames' and 'behaviors' metadata keys.

        Parameters
        ----------
        prediction_name : str
            the name of the prediction
        episode_names : list
            a list of string episode names to load the models from
        load_epochs : list, optional
            a list of integer epoch indices to load the model from; if None, the last ones are used
        parameters_update : dict, optional
            a dictionary of parameter updates
        augment_n : int, default 10
            the number of augmentations to average over
        data_path : str, optional
            the data path to run the prediction for
        mode : {'all', 'test', 'val', 'train'}
            the subset of the data to make the prediction for (forced to 'all' if data_path is not None)
        file_paths : set, optional
            a set of string file paths (data with all prefixes + feature files, in any order) to run the
            prediction for
        remove_saved_features : bool, default False
            if `True`, pre-computed features will be deleted
        submission : bool, default False
            if `True`, a MABe-22 style submission file is generated
        frame_number_map_file : str, optional
            path to the frame number map file
        force : bool, default False
            if `True`, an existing prediction with this name will be overwritten
        """

        self._check_prediction_validity(prediction_name, force=force)
        print(f"PREDICTION {prediction_name}")
        if submission:
            task = ...
            # TODO: add submission option to _make_prediction
            predicted = task.generate_submission(
                frame_number_map_file=frame_number_map_file,
                dataset=mode,
                augment_n=augment_n,
            )
            folder = os.path.join(
                self.project_path,
                "results",
                "predictions",
                f"{prediction_name}",
            )
            filename = os.path.join(folder, f"{prediction_name}.npy")
            np.save(filename, predicted, allow_pickle=True)
        else:
            try:
                (
                    task,
                    parameters,
                    mode,
                    prediction,
                    inference_time,
                ) = self._make_prediction(
                    prediction_name,
                    episode_names,
                    load_epochs,
                    parameters_update,
                    data_path,
                    file_paths,
                    mode,
                    augment_n,
                    evaluate=False,
                    embedding=embedding,
                )
                predicted = task.dataset(mode).generate_full_length_prediction(
                    prediction
                )
            except ValueError:
                (
                    task,
                    parameters,
                    mode,
                    predicted,
                    inference_time,
                ) = self._aggregate_predictions(
                    prediction_name,
                    episode_names,
                    load_epochs,
                    parameters_update,
                    data_path,
                    file_paths,
                    mode,
                    augment_n,
                    evaluate=False,
                    embedding=embedding,
                )
            folder = self.prediction_path(prediction_name)
            os.mkdir(folder)
            for video_id, prediction in predicted.items():
                with open(
                    os.path.join(
                        folder, video_id + f"_{prediction_name}_prediction.pickle"
                    ),
                    "wb",
                ) as f:
                    prediction["min_frames"], prediction["max_frames"] = task.dataset(
                        mode
                    ).get_min_max_frames(video_id)
                    behavior_indices = sorted(
                        [key for key in task.behaviors_dict() if key != -100]
                    )
                    prediction["behaviors"] = [
                        task.behaviors_dict()[key] for key in behavior_indices
                    ]
                    pickle.dump(prediction, f)
        if remove_saved_features:
            self._remove_stores(parameters)
        self._save_prediction(
            prediction_name,
            parameters,
            task.behaviors_dict(),
            embedding,
            inference_time,
        )
        print("\n")

    def evaluate_prediction(
        self,
        prediction_name: str,
        parameters_update: Dict = None,
        data_path: str = None,
        file_paths: Set = None,
        mode: str = None,
        remove_saved_features: bool = False,
    ) -> Tuple[float, dict]:
        """
        Load a saved prediction and evaluate it against the ground truth
        """

        with open(
            os.path.join(
                self.project_path, "results", "predictions", f"{prediction_name}.pickle"
            ),
            "rb",
        ) as f:
            prediction = pickle.load(f)
        if parameters_update is None:
            parameters_update = {}
        parameters_update = self._update(
            self._predictions().load_parameters(prediction_name), parameters_update
        )
        parameters_update.pop("model")
        task, parameters, mode = self._make_task_prediction(
            "_",
            load_episode=None,
            parameters_update=parameters_update,
            data_path=data_path,
            file_paths=file_paths,
            mode=mode,
        )
        results = task.evaluate_prediction(prediction, data=mode)
        if remove_saved_features:
            self._remove_stores(parameters)
        print("\n")
        return results
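
    # Usage sketch: averaging two trained episodes into one prediction and
    # reading one of the saved per-video files; the episode, prediction and
    # video names are hypothetical.
    #
    #     import os
    #     import pickle
    #
    #     project.run_prediction(
    #         "my_prediction",
    #         episode_names=["episode_A", "episode_B"],
    #         augment_n=10,
    #     )
    #     folder = project.prediction_path("my_prediction")
    #     with open(
    #         os.path.join(folder, "video1_my_prediction_prediction.pickle"), "rb"
    #     ) as f:
    #         prediction = pickle.load(f)
    #     # prediction maps clip ids to arrays, plus 'min_frames', 'max_frames'
    #     # and 'behaviors' metadata keys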

    def evaluate(
        self,
        episode_names: List,
        load_epochs: List = None,
        augment_n: int = 0,
        data_path: str = None,
        file_paths: Set = None,
        mode: str = None,
        parameters_update: Dict = None,
        multiple_episode_policy: str = "average",
        remove_saved_features: bool = False,
        skip_updating_meta: bool = True,
    ) -> Dict:
        """
        Load one or several models from previously run episodes to make an evaluation

        By default it will run on the test (or validation, if there is no test) subset of the project dataset.

        Parameters
        ----------
        episode_names : list
            a list of string episode names to load the models from
        load_epochs : list, optional
            a list of integer epoch indices to load the model from; if None, the last ones are used
        augment_n : int, default 0
            the number of augmentations to average over
        data_path : str, optional
            the data path to run the prediction for
        file_paths : set, optional
            a set of files to run the prediction for
        mode : {'test', 'val', 'train', 'all'}
            the subset of the data to make the prediction for (forced to 'all' if data_path is not None;
            by default 'test' if the test subset is not empty and 'val' otherwise)
        parameters_update : dict, optional
            a dictionary with parameter updates (cannot change model parameters)
        multiple_episode_policy : {'average', 'statistics'}
            how to handle multiple episodes: 'average' averages the predicted probabilities before evaluating,
            'statistics' evaluates each run separately and reports the mean and standard deviation per metric
        remove_saved_features : bool, default False
            if `True`, the dataset will be deleted
        skip_updating_meta : bool, default True
            if `False`, validation metrics computed here are saved to the episode meta records

        Returns
        -------
        metric : dict
            a dictionary of average values of metric functions
        """

        names = []
        for episode_name in episode_names:
            names += self._episodes().get_runs(episode_name)
        if len(set(episode_names)) == 1:
            print(f"EVALUATION {episode_names[0]}")
        else:
            print(f"EVALUATION {episode_names}")
        if len(names) > 1:
            evaluate = True
        else:
            evaluate = False
        if multiple_episode_policy == "average":
            try:
                (
                    task,
                    parameters,
                    mode,
                    prediction,
                    inference_time,
                ) = self._make_prediction(
                    "_",
                    episode_names,
                    load_epochs,
                    parameters_update,
                    mode=mode,
                    data_path=data_path,
                    file_paths=file_paths,
                    augment_n=augment_n,
                    evaluate=evaluate,
                )
            except ValueError:
                (
                    task,
                    parameters,
                    mode,
                    prediction,
                    inference_time,
                ) = self._aggregate_predictions(
                    "_",
                    episode_names,
                    load_epochs,
                    parameters_update,
                    mode=mode,
                    data_path=data_path,
                    file_paths=file_paths,
                    augment_n=augment_n,
                    evaluate=evaluate,
                )
            print("AGGREGATED:")
            _, results = task.evaluate_prediction(prediction, data=mode)
            if len(names) == 1 and mode == "val" and not skip_updating_meta:
                self._update_episode_metrics(names[0], results)
        elif multiple_episode_policy == "statistics":
            values = defaultdict(lambda: [])
            task = None
            for name in names:
                (
                    task,
                    parameters,
                    mode,
                    prediction,
                    inference_time,
                ) = self._make_prediction(
                    "_",
                    [name],
                    load_epochs,
                    parameters_update,
                    mode=mode,
                    data_path=data_path,
                    file_paths=file_paths,
                    augment_n=augment_n,
                    evaluate=evaluate,
                    task=task,
                )
                _, metrics = task.evaluate_prediction(prediction, data=mode)
                for metric_name, value in metrics.items():
                    values[metric_name].append(value)
                if mode == "val" and not skip_updating_meta:
                    self._update_episode_metrics(name, metrics)
            results = defaultdict(lambda: {})
            mean_string = ""
            std_string = ""
            for key, value_list in values.items():
                results[key]["mean"] = np.mean(value_list)
                results[key]["std"] = np.std(value_list)
                mean_string += f"{key} {np.mean(value_list):.3f}, "
                std_string += f"{key} {np.std(value_list):.3f}, "
            print("MEAN:")
            print(mean_string)
            print("STD:")
            print(std_string)
        else:
            raise ValueError(
                f"The {multiple_episode_policy} multiple episode policy is not recognized; please choose "
                f"from ['average', 'statistics']"
            )
        if len(names) > 0 and remove_saved_features:
            self._remove_stores(parameters)
        print(f"Inference time: {inference_time}")
        print("\n")
        return results
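
    # Usage sketch: evaluating all runs of an episode separately to get the
    # mean and standard deviation of each metric on the test subset.
    #
    #     results = project.evaluate(
    #         ["first_episode"],
    #         multiple_episode_policy="statistics",
    #         mode="test",
    #     )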
['data/overlap', 'results/recall']) 1658 print_results : bool, default True 1659 if True, the result will be printed to standard output 1660 1661 Returns 1662 ------- 1663 pd.DataFrame 1664 the filtered dataframe 1665 """ 1666 1667 episodes = self._episodes().list_episodes( 1668 episode_names, value_filter, display_parameters 1669 ) 1670 if print_results: 1671 print("TRAINING EPISODES") 1672 print(episodes) 1673 print("\n") 1674 return episodes 1675 1676 def list_predictions( 1677 self, 1678 episode_names: List = None, 1679 value_filter: str = "", 1680 display_parameters: List = None, 1681 print_results: bool = True, 1682 ) -> pd.DataFrame: 1683 """ 1684 Get a filtered pandas dataframe with prediction metadata 1685 1686 Parameters 1687 ---------- 1688 episode_names : list 1689 a list of strings of episode names 1690 value_filter : str 1691 a string of filters to apply; of this general structure: 1692 'group_name1/par_name1:(<>=)value1,group_name2/par_name2:(<>=)value2', e.g. 1693 'data/overlap:=50,results/recall:>0.5,data/feature_extraction:=kinematic' 1694 display_parameters : list 1695 list of parameters to display (e.g. ['data/overlap', 'results/recall']) 1696 print_results : bool, default True 1697 if True, the result will be printed to standard output 1698 1699 Returns 1700 ------- 1701 pd.DataFrame 1702 the filtered dataframe 1703 """ 1704 1705 predictions = self._predictions().list_episodes( 1706 episode_names, value_filter, display_parameters 1707 ) 1708 if print_results: 1709 print("PREDICTIONS") 1710 print(predictions) 1711 print("\n") 1712 return predictions 1713 1714 def list_searches( 1715 self, 1716 search_names: List = None, 1717 value_filter: str = "", 1718 display_parameters: List = None, 1719 print_results: bool = True, 1720 ) -> pd.DataFrame: 1721 """ 1722 Get a filtered pandas dataframe with hyperparameter search metadata 1723 1724 Parameters 1725 ---------- 1726 search_names : list 1727 a list of strings of search names 1728 value_filter : str 1729 a string of filters to apply; of this general structure: 1730 'group_name1/par_name1:(<>=)value1,group_name2/par_name2:(<>=)value2', e.g. 1731 'data/overlap:=50,results/recall:>0.5,data/feature_extraction:=kinematic' 1732 display_parameters : list 1733 list of parameters to display (e.g. 
['data/overlap', 'results/recall'])
        print_results : bool, default True
            if True, the result will be printed to standard output

        Returns
        -------
        pd.DataFrame
            the filtered dataframe
        """

        searches = self._searches().list_episodes(
            search_names, value_filter, display_parameters
        )
        if print_results:
            print("SEARCHES")
            print(searches)
            print("\n")
        return searches

    def get_best_parameters(
        self,
        search_name: str,
        round_to_binary: List = None,
    ):
        """
        Get the best parameters from a hyperparameter search as an update dictionary

        Parameters
        ----------
        search_name : str
            the name of the search
        round_to_binary : list, optional
            a list of string names of the loaded parameters that should be rounded to the nearest power of two

        Returns
        -------
        params : dict
            the best parameter dictionary (can be passed to `update_parameters` as `parameters_update`)
        """

        params, model = self._searches().get_best_params(
            search_name, round_to_binary=round_to_binary
        )
        params = self._update(params, {"general": {"model_name": model}})
        return params

    def list_best_parameters(
        self, search_name: str, print_results: bool = True
    ) -> Dict:
        """
        Get the raw dictionary of best parameters found by a search

        Parameters
        ----------
        search_name : str
            the name of the search
        print_results : bool, default True
            if True, the result will be printed to standard output

        Returns
        -------
        best_params : dict
            a dictionary of the best parameters where the keys are in '{group}/{name}' format
        """

        params = self._searches().get_best_params_raw(search_name)
        if print_results:
            print(f"SEARCH RESULTS {search_name}")
            for k, v in params.items():
                print(f"{k}: {v}")
            print("\n")
        return params
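
    # Example (a sketch; "search1" is a placeholder for a search recorded in this
    # project): load the best parameters found by a hyperparameter search and
    # write them into the project configuration.
    #
    #     best = project.get_best_parameters("search1")
    #     project.update_parameters(best)
    #
    # Alternatively, `project.update_parameters(load_search="search1")` loads the
    # search results directly.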

    def plot_episodes(
        self,
        episode_names: List,
        metrics: List,
        modes: List = None,
        title: str = None,
        episode_labels: List = None,
        save_path: str = None,
        add_hlines: List = None,
        epoch_limits: List = None,
        colors: List = None,
        add_highpoint_hlines: bool = False,
    ) -> None:
        """
        Plot episode training curves

        Parameters
        ----------
        episode_names : list
            a list of episode names to plot; to plot several episodes as one continuous experiment,
            combine them in a nested list (e.g. ['episode1', ['episode2', 'episode3']] to plot
            episode2 and episode3 as one experiment)
        metrics : list
            a list of metrics to plot
        modes : list, optional
            a list of modes to plot ('train' and/or 'val'; `['val']` by default)
        title : str, optional
            title for the plot
        episode_labels : list, optional
            a list of strings used to label the curves (has to be the same length as episode_names)
        save_path : str, optional
            the path to save the resulting plot
        add_hlines : list, optional
            a list of float values (or (value, label) tuples) to mark with horizontal lines
        epoch_limits : list, optional
            a list of two epoch values that limit the x axis of the plot
        colors : list, optional
            a list of matplotlib colors
        add_highpoint_hlines : bool, default False
            if `True`, horizontal lines will be added at the highest value of each episode
        """

        if modes is None:
            modes = ["val"]
        if add_hlines is None:
            add_hlines = []
        logs = []
        epochs = []
        labels = []
        if episode_labels is not None:
            assert len(episode_labels) == len(episode_names)
        for name_i, name in enumerate(episode_names):
            log_params = product(metrics, modes)
            for metric, mode in log_params:
                if episode_labels is not None:
                    label = episode_labels[name_i]
                else:
                    label = deepcopy(name)
                if len(modes) != 1:
                    label += f"_{mode}"
                if len(metrics) != 1:
                    label += f"_{metric}"
                labels.append(label)
                if isinstance(name, Iterable) and not isinstance(name, str):
                    epoch_list = defaultdict(lambda: [])
                    multi_logs = defaultdict(lambda: [])
                    for i, n in enumerate(name):
                        runs = self._episodes().get_runs(n)
                        if len(runs) > 1:
                            for run in runs:
                                index = run.split("::")[-1]
                                if multi_logs[index] == []:
                                    if multi_logs["null"] is None:
                                        raise RuntimeError(
                                            "The run indices are not consistent across episodes!"
                                        )
                                    else:
                                        multi_logs[index] += multi_logs["null"]
                                multi_logs[index] += list(
                                    self._episode(run).get_metric_log(mode, metric)
                                )
                                start = (
                                    0
                                    if len(epoch_list[index]) == 0
                                    else epoch_list[index][-1]
                                )
                                epoch_list[index] += [
                                    x + start
                                    for x in self._episode(run).get_epoch_list(mode)
                                ]
                            multi_logs["null"] = None
                        else:
                            if len(multi_logs.keys()) > 1:
                                raise RuntimeError(
                                    "Cannot plot a single-run episode after a multi-run episode!"
1882 ) 1883 multi_logs["null"] += list( 1884 self._episode(n).get_metric_log(mode, metric) 1885 ) 1886 start = ( 1887 0 1888 if len(epoch_list["null"]) == 0 1889 else epoch_list["null"][-1] 1890 ) 1891 epoch_list["null"] += [ 1892 x + start for x in self._episode(n).get_epoch_list(mode) 1893 ] 1894 if len(multi_logs.keys()) == 1: 1895 log = multi_logs["null"] 1896 epochs.append(epoch_list["null"]) 1897 else: 1898 log = tuple([v for k, v in multi_logs.items() if k != "null"]) 1899 epochs.append( 1900 tuple([v for k, v in epoch_list.items() if k != "null"]) 1901 ) 1902 else: 1903 runs = self._episodes().get_runs(name) 1904 if len(runs) > 1: 1905 log = [] 1906 for run in runs: 1907 tracked_metrics = self._episode(run).get_metrics() 1908 if metric in tracked_metrics: 1909 log.append( 1910 list( 1911 self._episode(run).get_metric_log(mode, metric) 1912 ) 1913 ) 1914 else: 1915 relevant = [] 1916 for m in tracked_metrics: 1917 m_split = m.split("_") 1918 if ( 1919 "_".join(m_split[:-1]) == metric 1920 and m_split[-1].isnumeric() 1921 ): 1922 relevant.append(m) 1923 if len(relevant) == 0: 1924 raise ValueError( 1925 f"The {metric} metric was not tracked at {run}" 1926 ) 1927 arr = 0 1928 for m in relevant: 1929 arr += self._episode(run).get_metric_log(mode, m) 1930 arr /= len(relevant) 1931 log.append(list(arr)) 1932 log = tuple(log) 1933 epochs.append( 1934 tuple( 1935 [ 1936 self._episode(run).get_epoch_list(mode) 1937 for run in runs 1938 ] 1939 ) 1940 ) 1941 else: 1942 tracked_metrics = self._episode(name).get_metrics() 1943 if metric in tracked_metrics: 1944 log = list(self._episode(name).get_metric_log(mode, metric)) 1945 else: 1946 relevant = [] 1947 for m in tracked_metrics: 1948 m_split = m.split("_") 1949 if ( 1950 "_".join(m_split[:-1]) == metric 1951 and m_split[-1].isnumeric() 1952 ): 1953 relevant.append(m) 1954 if len(relevant) == 0: 1955 raise ValueError( 1956 f"The {metric} metric was not tracked at {name}" 1957 ) 1958 arr = 0 1959 for m in relevant: 1960 arr += self._episode(name).get_metric_log(mode, m) 1961 arr /= len(relevant) 1962 log = list(arr) 1963 epochs.append(self._episode(name).get_epoch_list(mode)) 1964 logs.append(log) 1965 # if episode_labels is not None: 1966 # print(f'{len(episode_labels)=}, {len(logs)=}') 1967 # if len(episode_labels) != len(logs): 1968 1969 # raise ValueError( 1970 # f"The length of episode_labels ({len(episode_labels)}) has to be equal to the length of " 1971 # f"curves ({len(logs)})!" 1972 # ) 1973 # else: 1974 # labels = episode_labels 1975 if colors is None: 1976 colors = cm.rainbow(np.linspace(0, 1, len(logs))) 1977 if len(colors) != len(logs): 1978 raise ValueError( 1979 "The length of colors has to be equal to the length of curves (metrics * modes * episode_names)!" 
        )
        plt.figure()
        length = 0
        for log, label, color, epoch_list in zip(logs, labels, colors, epochs):
            if type(log) is list:
                if len(log) > length:
                    length = len(log)
                plt.plot(
                    epoch_list,
                    log,
                    label=label,
                    color=color,
                )
                if add_highpoint_hlines:
                    plt.axhline(np.max(log), linestyle="dashed", color=color)
            else:
                for l, xx in zip(log, epoch_list):
                    if len(l) > length:
                        length = len(l)
                    plt.plot(
                        xx,
                        l,
                        color=color,
                        alpha=0.2,
                    )
                if not all([len(x) == len(log[0]) for x in log]):
                    warnings.warn(
                        f"Got logs with unequal lengths in parallel runs for {label}"
                    )
                    log = list(log)
                    epoch_list = list(epoch_list)
                    for i, x in enumerate(epoch_list):
                        to_remove = []
                        # drop the epochs that are repeated because a run was restarted
                        for j, y in enumerate(x[1:]):
                            if y <= x[j]:
                                y_ind = x.index(y)
                                to_remove += list(range(y_ind, j + 1))
                        epoch_list[i] = [
                            y for j, y in enumerate(x) if j not in to_remove
                        ]
                        log[i] = [y for j, y in enumerate(log[i]) if j not in to_remove]
                    length = min([len(x) for x in log])
                    for i in range(len(log)):
                        log[i] = log[i][:length]
                        epoch_list[i] = epoch_list[i][:length]
                if not all([x == epoch_list[0] for x in epoch_list]):
                    raise RuntimeError(
                        f"Got different epoch indices in parallel runs for {label}"
                    )
                mean = np.array(log).mean(0)
                plt.plot(
                    epoch_list[0],
                    mean,
                    label=label,
                    color=color,
                )
                if add_highpoint_hlines:
                    plt.axhline(np.max(mean), linestyle="dashed", color=color)
        for x in add_hlines:
            label = None
            if isinstance(x, Iterable):
                x, label = x
            plt.axhline(x, label=label)
        plt.xlim((0, length))

        plt.legend()
        plt.xlabel("epochs")
        if len(metrics) == 1:
            plt.ylabel(metrics[0])
        else:
            plt.ylabel("value")
        if title is None:
            if len(episode_names) == 1:
                title = episode_names[0]
            elif len(metrics) == 1:
                title = metrics[0]
        if epoch_limits is not None:
            plt.xlim(epoch_limits)
        if title is not None:
            plt.title(title)
        # save before showing: plt.show() clears the current figure, so saving
        # afterwards would write out an empty image with most backends
        if save_path is not None:
            plt.savefig(save_path)
        plt.show()

    def update_parameters(
        self,
        parameters_update: Dict = None,
        load_search: str = None,
        load_parameters: List = None,
        round_to_binary: List = None,
    ) -> None:
        """
        Update the parameters in the project config files

        Parameters
        ----------
        parameters_update : dict, optional
            a dictionary of parameter updates
        load_search : str, optional
            the name of hyperparameter search results to load to config
        load_parameters : list, optional
            a list of lists of string names of the parameters to load from the searches
        round_to_binary : list, optional
            a list of string names of the loaded parameters that should be rounded to the nearest power of two
        """

        keys = [
            "general",
            "losses",
            "metrics",
            "ssl",
            "training",
            "data",
        ]
        parameters = self._read_parameters(catch_blanks=False)
        if parameters_update is not None:
            if "model" in parameters_update:
                model_params = parameters_update.pop("model")
            else:
                model_params = None
            if "features" in parameters_update:
                feat_params = parameters_update.pop("features")
            else:
                feat_params = None
            if "augmentations" in parameters_update:
                aug_params = parameters_update.pop("augmentations")
            else:
aug_params = None 2108 parameters = self._update(parameters, parameters_update) 2109 model_name = parameters["general"]["model_name"] 2110 parameters["model"] = self._open_yaml( 2111 os.path.join(self.project_path, "config", "model", f"{model_name}.yaml") 2112 ) 2113 if model_params is not None: 2114 parameters["model"] = self._update(parameters["model"], model_params) 2115 feat_name = parameters["general"]["feature_extraction"] 2116 parameters["features"] = self._open_yaml( 2117 os.path.join( 2118 self.project_path, "config", "features", f"{feat_name}.yaml" 2119 ) 2120 ) 2121 if feat_params is not None: 2122 parameters["features"] = self._update( 2123 parameters["features"], feat_params 2124 ) 2125 aug_name = options.extractor_to_transformer[ 2126 parameters["general"]["feature_extraction"] 2127 ] 2128 parameters["augmentations"] = self._open_yaml( 2129 os.path.join( 2130 self.project_path, "config", "augmentations", f"{aug_name}.yaml" 2131 ) 2132 ) 2133 if aug_params is not None: 2134 parameters["augmentations"] = self._update( 2135 parameters["augmentations"], aug_params 2136 ) 2137 if load_search is not None: 2138 parameters_update, model_name = self._searches().get_best_params( 2139 load_search, load_parameters, round_to_binary 2140 ) 2141 parameters["general"]["model_name"] = model_name 2142 parameters["model"] = self._open_yaml( 2143 os.path.join(self.project_path, "config", "model", f"{model_name}.yaml") 2144 ) 2145 parameters = self._update(parameters, parameters_update) 2146 for key in keys: 2147 with open( 2148 os.path.join(self.project_path, "config", f"{key}.yaml"), "w", encoding="utf-8" 2149 ) as f: 2150 YAML().dump(parameters[key], f) 2151 model_name = parameters["general"]["model_name"] 2152 model_path = os.path.join( 2153 self.project_path, "config", "model", f"{model_name}.yaml" 2154 ) 2155 with open(model_path, "w", encoding="utf-8") as f: 2156 YAML().dump(parameters["model"], f) 2157 features_name = parameters["general"]["feature_extraction"] 2158 features_path = os.path.join( 2159 self.project_path, "config", "features", f"{features_name}.yaml" 2160 ) 2161 with open(features_path, "w", encoding="utf-8") as f: 2162 YAML().dump(parameters["features"], f) 2163 aug_name = options.extractor_to_transformer[features_name] 2164 aug_path = os.path.join( 2165 self.project_path, "config", "augmentations", f"{aug_name}.yaml" 2166 ) 2167 with open(aug_path, "w", encoding="utf-8") as f: 2168 YAML().dump(parameters["augmentations"], f) 2169 2170 def get_summary( 2171 self, 2172 episode_names: list, 2173 method: str = "last", 2174 average: int = 1, 2175 metrics: List = None, 2176 ) -> Dict: 2177 """ 2178 Get a summary of episode statistics 2179 2180 If the episode has multiple runs, the statistics will be aggregated over all of them. 
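
        For example (a sketch; "exp1" stands for any episode recorded in this project)::

            project.get_summary(["exp1"], method="best", average=3)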

        Parameters
        ----------
        episode_names : list
            a list of the episode names
        method : {'best', 'last', 'epoch...'}
            the method for choosing the epochs: 'best' and 'last' take the best / last `average` epochs,
            'epochN' takes the epoch with index N (e.g. 'epoch100')
        average : int, default 1
            the number of epochs to average over (for each run)
        metrics : list, optional
            a list of metrics

        Returns
        -------
        statistics : dict
            a nested dictionary where first-level keys are metric names and second-level keys are 'mean' for the mean
            and 'std' for the standard deviation
        """

        runs = []
        for episode_name in episode_names:
            runs_ep = self._episodes().get_runs(episode_name)
            if len(runs_ep) == 0:
                raise RuntimeError(
                    f"There is no {episode_name} episode in the project memory"
                )
            runs += runs_ep
        if metrics is None:
            metrics = self._episode(runs[0]).get_metrics()

        values = {m: [] for m in metrics}
        for run in runs:
            for m in metrics:
                log = self._episode(run).get_metric_log(mode="val", metric_name=m)
                if method == "best":
                    log = sorted(log)
                    values[m] += list(log[-average:])
                elif method == "last":
                    if len(log) == 0:
                        # fall back to the final value stored in the meta records
                        episodes = self._episodes().data
                        if average == 1 and ("results", m) in episodes.columns:
                            values[m] += [episodes.loc[run, ("results", m)]]
                        else:
                            raise RuntimeError(f"Did not find {m} metric for {run} run")
                    else:
                        values[m] += list(log[-average:])
                elif method.startswith("epoch"):
                    epoch = int(method[5:]) - 1
                    pars = self._episodes().load_parameters(run)
                    step = int(pars["training"]["validation_interval"])
                    values[m] += [log[epoch // step]]
                else:
                    raise ValueError(
                        f"The {method} method is not recognized! Please choose from ['last', 'best', 'epoch...']"
                    )
        statistics = defaultdict(lambda: {})
        for m, v in values.items():
            statistics[m]["mean"] = np.mean(v)
            statistics[m]["std"] = np.std(v)
        print(f"SUMMARY {episode_names}")
        for m, v in statistics.items():
            print(f'{m}: mean {v["mean"]:.3f}, std {v["std"]:.3f}')
        print("\n")
        return dict(statistics)

    @staticmethod
    def remove_project(name: str, projects_path: str = None) -> None:
        """
        Remove all project files and experiment records and results
        """

        if projects_path is None:
            projects_path = os.path.join(str(Path.home()), "DLC2Action")
        project_path = os.path.join(projects_path, name)
        if os.path.exists(project_path):
            shutil.rmtree(project_path)

    def remove_saved_features(
        self,
        dataset_names: List = None,
        exceptions: List = None,
        remove_active: bool = False,
    ) -> None:
        """
        Remove saved pre-computed dataset files

        By default, all pre-computed features will be deleted.
        No essential information is lost: the saved files only exist to save computation time.
        Be careful about deleting datasets while training or inference is running, though.
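
        For example (a sketch)::

            project.remove_saved_features()  # delete all inactive pre-computed features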
2269 2270 Parameters 2271 ---------- 2272 dataset_names : list, optional 2273 a list of dataset names to delete (by default all names are added) 2274 exceptions : list, optional 2275 a list of dataset names to not be deleted 2276 remove_active : bool, default False 2277 if `False`, datasets used by unfinished episodes will not be deleted 2278 """ 2279 2280 print("Removing datasets...") 2281 if dataset_names is None: 2282 dataset_names = [] 2283 if exceptions is None: 2284 exceptions = [] 2285 if not remove_active: 2286 exceptions += self._episodes().get_active_datasets() 2287 dataset_path = os.path.join(self.project_path, "saved_datasets") 2288 if os.path.exists(dataset_path): 2289 if dataset_names == []: 2290 dataset_names = set([f.split(".")[0] for f in os.listdir(dataset_path)]) 2291 2292 to_remove = [ 2293 x 2294 for x in dataset_names 2295 if os.path.exists(os.path.join(dataset_path, x)) and x not in exceptions 2296 ] 2297 if len(to_remove) > 2: 2298 to_remove = tqdm(to_remove) 2299 for dataset in to_remove: 2300 shutil.rmtree(os.path.join(dataset_path, dataset)) 2301 to_remove = [ 2302 f"{x}.pickle" 2303 for x in dataset_names 2304 if os.path.exists(os.path.join(dataset_path, f"{x}.pickle")) 2305 and x not in exceptions 2306 ] 2307 for dataset in to_remove: 2308 os.remove(os.path.join(dataset_path, dataset)) 2309 names = self._saved_datasets().dataset_names() 2310 self._saved_datasets().remove(names) 2311 print("\n") 2312 2313 def remove_extra_checkpoints( 2314 self, episode_names: List = None, exceptions: List = None 2315 ) -> None: 2316 """ 2317 Remove intermediate model checkpoint files (only leave the results of the last epoch) 2318 2319 By default, all intermediate checkpoints will be deleted. 2320 Files in the model folder that are not associated with any record in the meta files are also deleted. 
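
        For example (a sketch; "exp1" is a placeholder episode name)::

            project.remove_extra_checkpoints(exceptions=["exp1"])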

        Parameters
        ----------
        episode_names : list, optional
            a list of episode names to clean (by default all names are added)
        exceptions : list, optional
            a list of episode names to not clean
        """

        model_path = os.path.join(self.project_path, "results", "model")
        try:
            all_names = self._episodes().data.index
        except Exception:
            all_names = os.listdir(model_path)
        if episode_names is None:
            episode_names = all_names
        if exceptions is None:
            exceptions = []
        to_remove = [x for x in episode_names if x not in exceptions]
        folders = os.listdir(model_path)
        for folder in folders:
            if folder not in all_names:
                shutil.rmtree(os.path.join(model_path, folder))
            elif folder in to_remove:
                files = os.listdir(os.path.join(model_path, folder))
                for file in sorted(files)[:-1]:
                    os.remove(os.path.join(model_path, folder, file))

    def remove_search(self, search_name: str) -> None:
        """
        Remove a hyperparameter search record

        Parameters
        ----------
        search_name : str
            the name of the search to remove
        """

        self._searches().remove_episode(search_name)
        graph_path = os.path.join(self.project_path, "results", "searches", search_name)
        if os.path.exists(graph_path):
            shutil.rmtree(graph_path)

    def remove_prediction(self, prediction_name: str) -> None:
        """
        Remove a prediction record

        Parameters
        ----------
        prediction_name : str
            the name of the prediction to remove
        """

        self._predictions().remove_episode(prediction_name)
        prediction_path = os.path.join(
            self.project_path, "results", "predictions", prediction_name
        )
        if os.path.exists(prediction_path):
            shutil.rmtree(prediction_path)

    def remove_episode(self, episode_name: str) -> None:
        """
        Remove all model, logs and metafile records related to an episode

        Parameters
        ----------
        episode_name : str
            the name of the episode to remove
        """

        runs = self._episodes().get_runs(episode_name)
        runs.append(episode_name)
        for run in runs:
            self._episodes().remove_episode(run)
            model_path = os.path.join(self.project_path, "results", "model", run)
            if os.path.exists(model_path):
                shutil.rmtree(model_path)
            log_path = os.path.join(self.project_path, "results", "logs", f"{run}.txt")
            if os.path.exists(log_path):
                os.remove(log_path)

    def prune_unfinished(self, exceptions: List = None) -> List:
        """
        Remove all interrupted episodes

        Remove all episodes that either don't have a log file, have fewer epochs in the log file than
        the training parameters specify, or have a model folder but no record. Note that this can remove
        episodes that are currently running!
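
        For example, to prune everything except an episode that is still training
        (a sketch; "exp_running" is a placeholder)::

            pruned = project.prune_unfinished(exceptions=["exp_running"])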
2409 2410 Parameters 2411 ---------- 2412 exceptions : list 2413 the episodes to keep even if they are interrupted 2414 2415 Returns 2416 ------- 2417 pruned : list 2418 a list of the episode names that were pruned 2419 """ 2420 2421 if exceptions is None: 2422 exceptions = [] 2423 unfinished = self._episodes().unfinished_episodes() 2424 unfinished = [x for x in unfinished if x not in exceptions] 2425 model_folders = os.listdir(os.path.join(self.project_path, "results", "model")) 2426 unfinished += [ 2427 x for x in model_folders if x not in self._episodes().list_episodes().index 2428 ] 2429 print(f"PRUNING {unfinished}") 2430 for episode_name in unfinished: 2431 self.remove_episode(episode_name) 2432 print(f"\n") 2433 return unfinished 2434 2435 def prediction_path(self, prediction_name: str) -> str: 2436 """ 2437 Get the path where prediction files are saved 2438 2439 Parameters 2440 ---------- 2441 prediction_name : str 2442 name of the prediction 2443 2444 Returns 2445 ------- 2446 prediction_path : str 2447 the file path 2448 """ 2449 2450 return os.path.join( 2451 self.project_path, "results", "predictions", f"{prediction_name}" 2452 ) 2453 2454 @classmethod 2455 def print_data_types(cls): 2456 print("DATA TYPES:") 2457 for key, value in cls.data_types().items(): 2458 print(f"{key}:") 2459 print(value.__doc__) 2460 2461 @classmethod 2462 def print_annotation_types(cls): 2463 print("ANNOTATION TYPES:") 2464 for key, value in cls.annotation_types().items(): 2465 print(f"{key}:") 2466 print(value.__doc__) 2467 2468 @staticmethod 2469 def data_types() -> List: 2470 """ 2471 Get available data types 2472 2473 Returns 2474 ------- 2475 list 2476 available data types 2477 """ 2478 2479 return options.input_stores 2480 2481 @staticmethod 2482 def annotation_types() -> List: 2483 """ 2484 Get available annotation types 2485 2486 Returns 2487 ------- 2488 list 2489 available annotation types 2490 """ 2491 2492 return options.annotation_stores 2493 2494 def _save_mask(self, file: Dict, mask_name: str): 2495 """ 2496 Save a mask file 2497 """ 2498 2499 if not os.path.exists(self._mask_path()): 2500 os.mkdir(self._mask_path()) 2501 with open(os.path.join(self._mask_path(), mask_name + ".pickle"), "wb") as f: 2502 pickle.dump(file, f) 2503 2504 def _load_mask(self, mask_name: str) -> Dict: 2505 """ 2506 Load a mask file 2507 """ 2508 2509 with open(os.path.join(self._mask_path(), mask_name + ".pickle"), "rb") as f: 2510 data = pickle.load(f) 2511 return data 2512 2513 def _thresholds(self) -> DecisionThresholds: 2514 """ 2515 Get the decision thresholds meta object 2516 """ 2517 2518 return DecisionThresholds(self._thresholds_path()) 2519 2520 def _episodes(self) -> SavedRuns: 2521 """ 2522 Get the episodes meta object 2523 2524 Returns 2525 ------- 2526 episodes : SavedRuns 2527 the episodes meta object 2528 """ 2529 2530 try: 2531 return SavedRuns(self._episodes_path(), self.project_path) 2532 except: 2533 self.load_metadata_backup() 2534 return SavedRuns(self._episodes_path(), self.project_path) 2535 2536 def _predictions(self) -> SavedRuns: 2537 """ 2538 Get the predictions meta object 2539 2540 Returns 2541 ------- 2542 predictions : SavedRuns 2543 the predictions meta object 2544 """ 2545 2546 try: 2547 return SavedRuns(self._predictions_path(), self.project_path) 2548 except: 2549 self.load_metadata_backup() 2550 return SavedRuns(self._predictions_path(), self.project_path) 2551 2552 def _saved_datasets(self) -> SavedStores: 2553 """ 2554 Get the datasets meta object 2555 2556 Returns 2557 
------- 2558 datasets : SavedStores 2559 the datasets meta object 2560 """ 2561 2562 try: 2563 return SavedStores(self._saved_datasets_path()) 2564 except: 2565 self.load_metadata_backup() 2566 return SavedStores(self._saved_datasets_path()) 2567 2568 def _prediction(self, name: str) -> Run: 2569 """ 2570 Get a prediction meta object 2571 2572 Parameters 2573 ---------- 2574 name : str 2575 episode name 2576 2577 Returns 2578 ------- 2579 prediction : Run 2580 the prediction meta object 2581 """ 2582 2583 try: 2584 return Run(name, self.project_path, meta_path=self._predictions_path()) 2585 except: 2586 self.load_metadata_backup() 2587 return Run(name, self.project_path, meta_path=self._predictions_path()) 2588 2589 def _episode(self, name: str) -> Run: 2590 """ 2591 Get an episode meta object 2592 2593 Parameters 2594 ---------- 2595 name : str 2596 episode name 2597 2598 Returns 2599 ------- 2600 episode : Run 2601 the episode meta object 2602 """ 2603 2604 try: 2605 return Run(name, self.project_path, meta_path=self._episodes_path()) 2606 except: 2607 self.load_metadata_backup() 2608 return Run(name, self.project_path, meta_path=self._episodes_path()) 2609 2610 def _searches(self) -> Searches: 2611 """ 2612 Get the hyperparameter search meta object 2613 2614 Returns 2615 ------- 2616 searches : Searches 2617 the searches meta object 2618 """ 2619 2620 try: 2621 return Searches(self._searches_path(), self.project_path) 2622 except: 2623 self.load_metadata_backup() 2624 return Searches(self._searches_path(), self.project_path) 2625 2626 def _update_configs(self) -> None: 2627 """ 2628 Update the project config files with newly added files and parameters 2629 """ 2630 2631 self.update_parameters({"data": {"data_path": self.data_path}}) 2632 folders = ["augmentations", "features", "model"] 2633 original_path = os.path.join( 2634 os.path.dirname(os.path.dirname(__file__)), "config" 2635 ) 2636 project_path = os.path.join(self.project_path, "config") 2637 filenames = [x for x in os.listdir(original_path) if x.endswith("yaml")] 2638 for folder in folders: 2639 filenames += [ 2640 os.path.join(folder, x) 2641 for x in os.listdir(os.path.join(original_path, folder)) 2642 ] 2643 filenames.append(os.path.join("data", f"{self.data_type}.yaml")) 2644 if self.annotation_type != "none": 2645 filenames.append(os.path.join("annotation", f"{self.annotation_type}.yaml")) 2646 for file in filenames: 2647 filepath_original = os.path.join(original_path, file) 2648 if file.startswith("data") or file.startswith("annotation"): 2649 file = os.path.basename(file) 2650 filepath_project = os.path.join(project_path, file) 2651 if not os.path.exists(filepath_project): 2652 shutil.copy(filepath_original, filepath_project) 2653 else: 2654 original_pars = self._open_yaml(filepath_original) 2655 project_pars = self._open_yaml(filepath_project) 2656 to_remove = [] 2657 for key, value in project_pars.items(): 2658 if key not in original_pars: 2659 if key not in ["data_type", "annotation_type"]: 2660 to_remove.append(key) 2661 for key in to_remove: 2662 project_pars.pop(key) 2663 to_remove = [] 2664 for key, value in original_pars.items(): 2665 if key in project_pars: 2666 to_remove.append(key) 2667 for key in to_remove: 2668 original_pars.pop(key) 2669 project_pars = self._update(project_pars, original_pars) 2670 with open(filepath_project, "w", encoding="utf-8") as f: 2671 YAML().dump(project_pars, f) 2672 2673 def _update_project(self) -> None: 2674 """ 2675 Update project files with the current version 2676 """ 2677 
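        # NOTE: the version strings below are compared lexicographically, so
        # e.g. "0.10" sorts before "0.9". A more robust sketch (assuming the
        # third-party `packaging` library is acceptable here) would be:
        #     from packaging.version import Version
        #     if Version(project_version) < Version(__version__): ...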
        version_file = self._version_path()
        ok = True
        if not os.path.exists(version_file):
            ok = False
        else:
            with open(version_file) as f:
                project_version = f.read()
            if project_version < __version__:
                ok = False
            elif project_version > __version__:
                warnings.warn(
                    f"The project expects a higher dlc2action version ({project_version}), please update!"
                )
        if not ok:
            project_config_path = os.path.join(self.project_path, "config")
            # use __file__ here: __path__ is not defined in this module and
            # would make this call fail
            config_path = os.path.join(
                os.path.dirname(os.path.dirname(__file__)), "config"
            )
            episodes = self._episodes()
            folders = ["annotation", "augmentations", "data", "features", "model"]

            project_annotation_configs = os.listdir(
                os.path.join(project_config_path, "annotation")
            )
            annotation_configs = os.listdir(os.path.join(config_path, "annotation"))
            for ann_config in annotation_configs:
                if ann_config not in project_annotation_configs:
                    shutil.copytree(
                        os.path.join(config_path, "annotation", ann_config),
                        os.path.join(project_config_path, "annotation", ann_config),
                        dirs_exist_ok=True,
                    )
                else:
                    project_pars = self._open_yaml(
                        os.path.join(project_config_path, "annotation", ann_config)
                    )
                    pars = self._open_yaml(
                        os.path.join(config_path, "annotation", ann_config)
                    )
                    new_keys = set(pars.keys()) - set(project_pars.keys())
                    for key in new_keys:
                        project_pars[key] = pars[key]
                        c = self._get_comment(pars.ca.items.get(key))
                        project_pars.yaml_add_eol_comment(c, key=key)
                        episodes.update(
                            condition=f"general/annotation_type::={ann_config}",
                            update={f"data/{key}": pars[key]},
                        )

    def _initialize_project(
        self,
        data_type: str,
        annotation_type: str = None,
        data_path: str = None,
        annotation_path: str = None,
        copy: bool = True,
    ) -> None:
        """
        Initialize a new project
        """

        if data_type not in self.data_types():
            raise ValueError(
                f"The {data_type} data type is not available. "
                f"Please choose from {self.data_types()}"
            )
        if annotation_type not in self.annotation_types():
            raise ValueError(
                f"The {annotation_type} annotation type is not available. "
                f"Please choose from {self.annotation_types()}"
            )
" 2747 f"Please choose from {self.annotation_types()}" 2748 ) 2749 os.mkdir(self.project_path) 2750 folders = ["results", "saved_datasets", "meta", "config"] 2751 for f in folders: 2752 os.mkdir(os.path.join(self.project_path, f)) 2753 results_subfolders = [ 2754 "model", 2755 "logs", 2756 "predictions", 2757 "splits", 2758 "searches", 2759 ] 2760 for sf in results_subfolders: 2761 os.mkdir(os.path.join(self.project_path, "results", sf)) 2762 if data_path is not None: 2763 if copy: 2764 os.mkdir(os.path.join(self.project_path, "data")) 2765 shutil.copytree( 2766 data_path, 2767 os.path.join(self.project_path, "data"), 2768 dirs_exist_ok=True, 2769 ) 2770 data_path = os.path.join(self.project_path, "data") 2771 if annotation_path is not None: 2772 if copy: 2773 os.mkdir(os.path.join(self.project_path, "annotation")) 2774 shutil.copytree( 2775 annotation_path, 2776 os.path.join(self.project_path, "annotation"), 2777 dirs_exist_ok=True, 2778 ) 2779 annotation_path = os.path.join(self.project_path, "annotation") 2780 self._generate_config( 2781 data_type, 2782 annotation_type, 2783 data_path=data_path, 2784 annotation_path=annotation_path, 2785 ) 2786 self._generate_meta() 2787 2788 def _read_types(self) -> Tuple[str, str]: 2789 """ 2790 Get data type and annotation type from existing project files 2791 """ 2792 2793 config_path = os.path.join(self.project_path, "config", "general.yaml") 2794 with open(config_path) as f: 2795 pars = YAML().load(f) 2796 data_type = pars["data_type"] 2797 annotation_type = pars["annotation_type"] 2798 return annotation_type, data_type 2799 2800 def _read_paths(self) -> Tuple[str, str]: 2801 """ 2802 Get data type and annotation type from existing project files 2803 """ 2804 2805 config_path = os.path.join(self.project_path, "config", "data.yaml") 2806 with open(config_path) as f: 2807 pars = YAML().load(f) 2808 data_path = pars["data_path"] 2809 annotation_path = pars["annotation_path"] 2810 return annotation_path, data_path 2811 2812 def _generate_config( 2813 self, data_type: str, annotation_type: str, data_path: str, annotation_path: str 2814 ) -> None: 2815 """ 2816 Initialize the config files 2817 """ 2818 2819 default_path = os.path.join( 2820 os.path.dirname(os.path.dirname(__file__)), "config" 2821 ) 2822 config_path = os.path.join(self.project_path, "config") 2823 files = ["losses", "metrics", "ssl", "training"] 2824 for f in files: 2825 shutil.copy(os.path.join(default_path, f"{f}.yaml"), config_path) 2826 shutil.copytree( 2827 os.path.join(default_path, "model"), os.path.join(config_path, "model") 2828 ) 2829 shutil.copytree( 2830 os.path.join(default_path, "features"), 2831 os.path.join(config_path, "features"), 2832 ) 2833 shutil.copytree( 2834 os.path.join(default_path, "augmentations"), 2835 os.path.join(config_path, "augmentations"), 2836 ) 2837 yaml = YAML() 2838 data_param_path = os.path.join(default_path, "data", f"{data_type}.yaml") 2839 if os.path.exists(data_param_path): 2840 with open(data_param_path, encoding="utf-8") as f: 2841 data_params = yaml.load(f) 2842 if data_params is None: 2843 data_params = {} 2844 if annotation_type is None: 2845 ann_params = {} 2846 else: 2847 ann_param_path = os.path.join( 2848 default_path, "annotation", f"{annotation_type}.yaml" 2849 ) 2850 if os.path.exists(ann_param_path): 2851 ann_params = self._open_yaml(ann_param_path) 2852 elif annotation_type == "none": 2853 ann_params = {} 2854 else: 2855 raise ValueError( 2856 f"The {annotation_type} data type is not available. 
" 2857 f"Please choose from {BehaviorDataset.annotation_types()}" 2858 ) 2859 if ann_params is None: 2860 ann_params = {} 2861 data_params = self._update(data_params, ann_params) 2862 data_params["data_path"] = data_path 2863 data_params["annotation_path"] = annotation_path 2864 with open(os.path.join(config_path, "data.yaml"), "w", encoding="utf-8") as f: 2865 yaml.dump(data_params, f) 2866 with open(os.path.join(default_path, "general.yaml"), encoding="utf-8") as f: 2867 general_params = yaml.load(f) 2868 general_params["data_type"] = data_type 2869 general_params["annotation_type"] = annotation_type 2870 with open(os.path.join(config_path, "general.yaml"), "w", encoding="utf-8") as f: 2871 yaml.dump(general_params, f) 2872 2873 def _generate_meta(self) -> None: 2874 """ 2875 Initialize the meta files 2876 """ 2877 2878 config_file = os.path.join(self.project_path, "config") 2879 meta_fields = ["time"] 2880 columns = [("meta", field) for field in meta_fields] 2881 episodes = pd.DataFrame(columns=pd.MultiIndex.from_tuples(columns)) 2882 episodes.to_pickle(self._episodes_path()) 2883 meta_fields = ["time", "objective"] 2884 result_fields = ["best_params", "best_value"] 2885 columns = [("meta", field) for field in meta_fields] + [ 2886 ("results", field) for field in result_fields 2887 ] 2888 searches = pd.DataFrame(columns=pd.MultiIndex.from_tuples(columns)) 2889 searches.to_pickle(self._searches_path()) 2890 meta_fields = ["time"] 2891 columns = [("meta", field) for field in meta_fields] 2892 predictions = pd.DataFrame(columns=pd.MultiIndex.from_tuples(columns)) 2893 predictions.to_pickle(self._predictions_path()) 2894 with open(os.path.join(config_file, "data.yaml")) as f: 2895 data_keys = list(YAML().load(f).keys()) 2896 saved_data = pd.DataFrame(columns=data_keys) 2897 saved_data.to_pickle(self._saved_datasets_path()) 2898 pd.DataFrame().to_pickle(self._thresholds_path()) 2899 # with open(self._version_path()) as f: 2900 # f.write(__version__) 2901 2902 def _open_yaml(self, path: str) -> CommentedMap: 2903 """ 2904 Load a parameter dictionary from a .yaml file 2905 """ 2906 2907 with open(path, encoding="utf-8") as f: 2908 data = YAML().load(f) 2909 if data is None: 2910 data = {} 2911 return data 2912 2913 def _compare(self, d: Dict, u: Dict, allow_diff: float = 1e-7): 2914 """ 2915 Compare nested dictionaries with 'almost equal' condition 2916 """ 2917 2918 ok = True 2919 if u.keys() != d.keys(): 2920 ok = False 2921 else: 2922 for k, v in u.items(): 2923 if isinstance(v, Mapping): 2924 ok = self._compare(d[k], v, allow_diff=allow_diff) 2925 else: 2926 if isinstance(v, float) or isinstance(d[k], float): 2927 if not isinstance(d[k], float) and not isinstance(d[k], int): 2928 ok = False 2929 elif not isinstance(v, float) and not isinstance(v, int): 2930 ok = False 2931 elif np.abs(v - d[k]) > allow_diff: 2932 ok = False 2933 elif v != d[k]: 2934 ok = False 2935 return ok 2936 2937 def _check_comment(self, comment_sequence: List) -> bool: 2938 """ 2939 Check if a comment already exists in a ruamel.yaml comment sequence 2940 """ 2941 2942 if comment_sequence is None: 2943 return False 2944 c = self._get_comment(comment_sequence) 2945 if c != "": 2946 return True 2947 else: 2948 return False 2949 2950 def _get_comment(self, comment_sequence: List, strip=True) -> str: 2951 """ 2952 Get the comment string from a ruamel.yaml comment sequence 2953 """ 2954 2955 if comment_sequence is None: 2956 return "" 2957 c = "" 2958 for cm in comment_sequence: 2959 if cm is not None: 2960 if 
isinstance(cm, Iterable): 2961 for c in cm: 2962 if c is not None: 2963 c = c.value 2964 break 2965 break 2966 else: 2967 c = cm.value 2968 break 2969 if strip: 2970 c = c.strip() 2971 return c 2972 2973 def _update(self, d: Union[CommentedMap, Dict], u: Union[CommentedMap, Dict]): 2974 """ 2975 Update a nested dictionary 2976 """ 2977 2978 if "general" in u and "model_name" in u["general"] and "model" in d: 2979 model_name = u["general"]["model_name"] 2980 if d["general"]["model_name"] != model_name: 2981 d["model"] = self._open_yaml( 2982 os.path.join( 2983 self.project_path, "config", "model", f"{model_name}.yaml" 2984 ) 2985 ) 2986 d_copied = deepcopy(d) 2987 for k, v in u.items(): 2988 if ( 2989 k in d_copied 2990 and isinstance(d_copied[k], list) 2991 and isinstance(v, Mapping) 2992 and all([isinstance(x, int) for x in v.keys()]) 2993 ): 2994 for kk, vv in v.items(): 2995 d_copied[k][kk] = vv 2996 elif ( 2997 isinstance(v, Mapping) 2998 and k in d_copied 2999 and isinstance(d_copied[k], Mapping) 3000 ): 3001 if d_copied[k] is None: 3002 d_k = CommentedMap() 3003 else: 3004 d_k = d_copied[k] 3005 d_copied[k] = self._update(d_k, v) 3006 else: 3007 d_copied[k] = v 3008 if isinstance(u, CommentedMap) and u.ca.items.get(k) is not None: 3009 c = self._get_comment(u.ca.items.get(k), strip=False) 3010 if isinstance(d_copied, CommentedMap) and not self._check_comment( 3011 d_copied.ca.items.get(k) 3012 ): 3013 d_copied.yaml_add_eol_comment(c, key=k) 3014 return d_copied 3015 3016 def _update_with_search( 3017 self, 3018 d: Dict, 3019 search_name: str, 3020 load_parameters: list = None, 3021 round_to_binary: list = None, 3022 ): 3023 """ 3024 Update a dictionary with best parameters from a hyperparameter search 3025 """ 3026 3027 u, _ = self._searches().get_best_params( 3028 search_name, load_parameters, round_to_binary 3029 ) 3030 return self._update(d, u) 3031 3032 def _read_parameters(self, catch_blanks=True) -> Dict: 3033 """ 3034 Compose a parameter dictionary to create a task from the config files 3035 """ 3036 3037 config_path = os.path.join(self.project_path, "config") 3038 keys = [ 3039 "data", 3040 "general", 3041 "losses", 3042 "metrics", 3043 "ssl", 3044 "training", 3045 ] 3046 parameters = {} 3047 for key in keys: 3048 parameters[key] = self._open_yaml(os.path.join(config_path, f"{key}.yaml")) 3049 features = parameters["general"]["feature_extraction"] 3050 parameters["features"] = self._open_yaml( 3051 os.path.join(config_path, "features", f"{features}.yaml") 3052 ) 3053 transformer = options.extractor_to_transformer[features] 3054 parameters["augmentations"] = self._open_yaml( 3055 os.path.join(config_path, "augmentations", f"{transformer}.yaml") 3056 ) 3057 model = parameters["general"]["model_name"] 3058 parameters["model"] = self._open_yaml( 3059 os.path.join(config_path, "model", f"{model}.yaml") 3060 ) 3061 # input = parameters["general"]["input"] 3062 # parameters["model"] = self._open_yaml( 3063 # os.path.join(config_path, "model", f"{model}.yaml") 3064 # ) 3065 if catch_blanks: 3066 blanks = self._get_blanks() 3067 if len(blanks) > 0: 3068 self.list_blanks() 3069 raise ValueError( 3070 f"Please fill in all the blanks before running experiments" 3071 ) 3072 return parameters 3073 3074 def set_main_parameters(self, model_name: str = None, metric_names: List = None): 3075 """ 3076 Select the model and the metrics 3077 3078 Parameters 3079 ---------- 3080 model_name : str, optional 3081 model name; run `project.help("model") to find out more 3082 metric_names : list, 
optional 3083 a list of metric function names; run `project.help("metrics") to find out more 3084 """ 3085 3086 pars = {"general": {}} 3087 if model_name is not None: 3088 assert model_name in options.models 3089 pars["general"]["model_name"] = model_name 3090 if metric_names is not None: 3091 for metric in metric_names: 3092 assert metric in options.metrics 3093 pars["general"]["metric_functions"] = metric_names 3094 self.update_parameters(pars) 3095 3096 def help(self, keyword: str = None): 3097 """ 3098 Get information on available options 3099 3100 Parameters 3101 ---------- 3102 keyword : str, optional 3103 the keyword for options (run without arguments to see which keywords are available) 3104 3105 """ 3106 3107 if keyword is None: 3108 print("AVAILABLE HELP FUNCTIONS:") 3109 print("- Try running `project.help(keyword)` with the following keywords:") 3110 print(" - model: to get more information on available models,") 3111 print( 3112 " - features: to get more information on available feature extraction modes," 3113 ) 3114 print( 3115 " - partition_method: to get more information on available train/test/val partitioning methods," 3116 ) 3117 print(" - metrics: to see a list of available metric functions.") 3118 print(" - data: to see help for expected data structure") 3119 print( 3120 "- To start working with this project, first run `project.list_blanks()` to check which parameters need to be filled in." 3121 ) 3122 print( 3123 "- After a model and metrics are set, run `project.list_basic_parameters()` to see a list of the most important parameters that you might want to modify" 3124 ) 3125 print( 3126 f"- If you want to dig deeper, get the full dictionary with project._read_parameters() (it is a `ruamel.yaml.comments.CommentedMap` instance)." 3127 ) 3128 elif keyword == "model": 3129 print("MODELS:") 3130 for key, model in options.models.items(): 3131 print(f"{key}:") 3132 print(model.__doc__) 3133 elif keyword == "features": 3134 print("FEATURE EXTRACTORS:") 3135 for key, extractor in options.feature_extractors.items(): 3136 print(f"{key}:") 3137 print(extractor.__doc__) 3138 elif keyword == "partition_method": 3139 print("PARTITION METHODS:") 3140 print( 3141 BehaviorDataset.partition_train_test_val.__doc__.split( 3142 "The partitioning method:" 3143 )[1].split("val_frac :")[0] 3144 ) 3145 elif keyword == "metrics": 3146 print("METRICS:") 3147 for key, metric in options.metrics.items(): 3148 print(f"{key}:") 3149 print(metric.__doc__) 3150 elif keyword == "data": 3151 print("DATA:") 3152 print(f"Video data: {self.data_type}") 3153 print(options.input_stores[self.data_type].__doc__) 3154 print(f"Annotation data: {self.annotation_type}") 3155 print(options.annotation_stores[self.annotation_type].__doc__) 3156 print( 3157 "Annotation path and data path don't have to be separate, you can keep everything in one folder." 
            )
        else:
            raise ValueError(f"The {keyword} keyword is not recognized")
        print("\n")

    def _process_value(self, value):
        if isinstance(value, str):
            value = f'"{value}"'
        elif isinstance(value, CommentedSet):
            value = {x for x in value}
        return value

    def _get_blanks(self):
        caught = []
        parameters = self._read_parameters(catch_blanks=False)
        for big_key, big_value in parameters.items():
            for key, value in big_value.items():
                if value == "???":
                    caught.append(
                        (big_key, key, self._get_comment(big_value.ca.items.get(key)))
                    )
        return caught

    def list_blanks(self, blanks=None):
        """
        List parameters that need to be filled in

        Parameters
        ----------
        blanks : list, optional
            a list of the parameters to list, if already known
        """

        if blanks is None:
            blanks = self._get_blanks()
        if len(blanks) > 0:
            to_update = defaultdict(lambda: [])
            for b, k, c in blanks:
                to_update[b].append((k, c))
            print("Before running experiments, please update all the blanks.")
            print("To do that, you can run this:")
            print("--------------------------------------------------------")
            print("project.update_parameters(")
            print("    {")
            for big_key, keys in to_update.items():
                print(f'        "{big_key}": {{')
                for key, comment in keys:
                    print(f'            "{key}": ..., {comment}')
                print("        },")
            print("    }")
            print(")")
            print("--------------------------------------------------------")
            print("Replace ... with the relevant values.")
        else:
            print("There are no blanks left!")

    def list_basic_parameters(
        self,
    ):
        """
        Print a list of the most relevant parameters and the code to modify them
        """

        parameters = self._read_parameters()
        print("BASIC PARAMETERS:")
        model_name = parameters["general"]["model_name"]
        metric_names = parameters["general"]["metric_functions"]
        loss_name = parameters["general"]["loss_function"]
        feature_extraction = parameters["general"]["feature_extraction"]
        print("Here is a list of the current parameters.")
        print(
            "You can copy this code, change the parameters you want to set and run it to update the project config."
3230 ) 3231 print("--------------------------------------------------------") 3232 print("project.update_parameters(") 3233 print(" {") 3234 for group in ["general", "data", "training"]: 3235 print(f' "{group}": {{') 3236 for key in options.basic_parameters[group]: 3237 if key in parameters[group]: 3238 print( 3239 f' "{key}": {self._process_value(parameters[group][key])}, {self._get_comment(parameters[group].ca.items.get(key))}' 3240 ) 3241 print(" },") 3242 print(' "losses": {') 3243 print(f' "{loss_name}": {{') 3244 for key in options.basic_parameters["losses"][loss_name]: 3245 if key in parameters["losses"][loss_name]: 3246 print( 3247 f' "{key}": {self._process_value(parameters["losses"][loss_name][key])}, {self._get_comment(parameters["losses"][loss_name].ca.items.get(key))}' 3248 ) 3249 print(" },") 3250 print(" },") 3251 print(' "metrics": {') 3252 for metric in metric_names: 3253 print(f' "{metric}": {{') 3254 for key in parameters["metrics"][metric]: 3255 print( 3256 f' "{key}": {self._process_value(parameters["metrics"][metric][key])}, {self._get_comment(parameters["metrics"][metric].ca.items.get(key))}' 3257 ) 3258 print(" },") 3259 print(" },") 3260 print(' "model": {') 3261 for key in options.basic_parameters["model"][model_name]: 3262 if key in parameters["model"]: 3263 print( 3264 f' "{key}": {self._process_value(parameters["model"][key])}, {self._get_comment(parameters["model"].ca.items.get(key))}' 3265 ) 3266 3267 print(" },") 3268 print(' "features": {') 3269 for key in options.basic_parameters["features"][feature_extraction]: 3270 if key in parameters["features"]: 3271 print( 3272 f' "{key}": {self._process_value(parameters["features"][key])}, {self._get_comment(parameters["features"].ca.items.get(key))}' 3273 ) 3274 3275 print(" },") 3276 print(' "augmentations": {') 3277 for key in options.basic_parameters["augmentations"][feature_extraction]: 3278 if key in parameters["augmentations"]: 3279 print( 3280 f' "{key}": {self._process_value(parameters["augmentations"][key])}, {self._get_comment(parameters["augmentations"].ca.items.get(key))}' 3281 ) 3282 print(" },") 3283 print(" },") 3284 print(")") 3285 print("--------------------------------------------------------") 3286 print("\n") 3287 3288 def _create_record( 3289 self, 3290 episode_name: str, 3291 behaviors_dict: Dict, 3292 load_episode: str = None, 3293 parameters_update: Dict = None, 3294 task: TaskDispatcher = None, 3295 load_epoch: int = None, 3296 load_search: str = None, 3297 load_parameters: list = None, 3298 round_to_binary: list = None, 3299 load_strict: bool = True, 3300 n_seeds: int = 1, 3301 ) -> TaskDispatcher: 3302 """ 3303 Create a meta data episode record 3304 """ 3305 3306 if episode_name in self._episodes().data.index: 3307 return 3308 if type(n_seeds) is not int or n_seeds < 1: 3309 raise ValueError( 3310 f"The n_seeds parameter has to be an integer larger than 0; got {n_seeds}" 3311 ) 3312 if parameters_update is None: 3313 parameters_update = {} 3314 parameters = self._read_parameters() 3315 parameters = self._update(parameters, parameters_update) 3316 if load_search is not None: 3317 parameters = self._update_with_search( 3318 parameters, load_search, load_parameters, round_to_binary 3319 ) 3320 parameters = self._fill( 3321 parameters, 3322 episode_name, 3323 load_episode, 3324 load_epoch=load_epoch, 3325 only_load_model=True, 3326 load_strict=load_strict, 3327 continuing=True, 3328 ) 3329 self._save_episode(episode_name, parameters, behaviors_dict) 3330 return task 3331 3332 def 
_save_thresholds( 3333 self, 3334 episode_names: List, 3335 metric_name: str, 3336 parameters: Dict, 3337 thresholds: List, 3338 load_epochs: List, 3339 ): 3340 """ 3341 Save optimal decision thresholds in the meta records 3342 """ 3343 3344 metric_parameters = parameters["metrics"][metric_name] 3345 self._thresholds().save_thresholds( 3346 episode_names, load_epochs, metric_name, metric_parameters, thresholds 3347 ) 3348 3349 def _save_episode( 3350 self, 3351 episode_name: str, 3352 parameters: Dict, 3353 behaviors_dict: Dict, 3354 suppress_validation: bool = False, 3355 training_time: str = None, 3356 norm_stats: Dict = None, 3357 ) -> None: 3358 """ 3359 Save an episode in the meta files 3360 """ 3361 3362 try: 3363 split_info = self._split_info_from_filename( 3364 parameters["training"]["split_path"] 3365 ) 3366 parameters["training"]["partition_method"] = split_info["partition_method"] 3367 except: 3368 pass 3369 if norm_stats is not None: 3370 norm_stats = dict(norm_stats) 3371 parameters["training"]["stats"] = norm_stats 3372 self._episodes().save_episode( 3373 episode_name, 3374 parameters, 3375 behaviors_dict, 3376 suppress_validation=suppress_validation, 3377 training_time=training_time, 3378 ) 3379 3380 def _update_episode_results( 3381 self, 3382 episode_name: str, 3383 logs: Tuple, 3384 training_time: str = None, 3385 ) -> None: 3386 """ 3387 Save the results of a run in the meta files 3388 """ 3389 3390 self._episodes().update_episode_results(episode_name, logs, training_time) 3391 3392 def _save_prediction( 3393 self, 3394 episode_name: str, 3395 parameters: Dict, 3396 behaviors_dict: Dict, 3397 embedding: bool = False, 3398 inference_time: str = None, 3399 ) -> None: 3400 """ 3401 Save a prediction in the meta files 3402 """ 3403 3404 parameters = self._update( 3405 parameters, 3406 {"meta": {"embedding": embedding, "inference_time": inference_time}}, 3407 ) 3408 self._predictions().save_episode(episode_name, parameters, behaviors_dict) 3409 3410 def _save_search( 3411 self, 3412 search_name: str, 3413 parameters: Dict, 3414 n_trials: int, 3415 best_params: Dict, 3416 best_value: float, 3417 metric: str, 3418 search_space: Dict, 3419 ) -> None: 3420 """ 3421 Save a hyperparameter search in the meta files 3422 """ 3423 3424 self._searches().save_search( 3425 search_name, 3426 parameters, 3427 n_trials, 3428 best_params, 3429 best_value, 3430 metric, 3431 search_space, 3432 ) 3433 3434 def _save_stores(self, parameters: Dict) -> None: 3435 """ 3436 Save a pickled dataset in the meta files 3437 """ 3438 3439 name = os.path.basename(parameters["data"]["feature_save_path"]) 3440 self._saved_datasets().save_store(name, self._get_data_pars(parameters)) 3441 self.create_metadata_backup() 3442 3443 def _remove_stores(self, parameters: Dict, remove_active: bool = False) -> None: 3444 """ 3445 Remove the pre-computed features folder 3446 """ 3447 3448 name = os.path.basename(parameters["data"]["feature_save_path"]) 3449 if remove_active or name not in self._episodes().get_active_datasets(): 3450 self.remove_saved_features([name]) 3451 3452 def _check_episode_validity( 3453 self, episode_name: str, allow_doublecolon: bool = False, force: bool = False 3454 ) -> None: 3455 """ 3456 Check whether the episode name is valid 3457 """ 3458 3459 if episode_name.startswith("_"): 3460 raise ValueError( 3461 "Names starting with an underscore are reserved by dlc2action and cannot be used!" 3462 ) 3463 elif "." in episode_name: 3464 raise ValueError("Names containing '.' 
cannot be used!") 3465 if not allow_doublecolon and "::" in episode_name: 3466 raise ValueError( 3467 "Names containing '::' are reserved by dlc2action and cannot be used!" 3468 ) 3469 if force: 3470 self.remove_episode(episode_name) 3471 elif not self._episodes().check_name_validity(episode_name): 3472 raise ValueError( 3473 f"The {episode_name} name is already taken! Set force=True to overwrite." 3474 ) 3475 3476 def _check_search_validity(self, search_name: str, force: bool = False) -> None: 3477 """ 3478 Check whether the search name is valid 3479 """ 3480 3481 if search_name.startswith("_"): 3482 raise ValueError( 3483 "Names starting with an underscore are reserved by dlc2action and cannot be used!" 3484 ) 3485 elif "." in search_name: 3486 raise ValueError("Names containing '.' cannot be used!") 3487 if force: 3488 self.remove_search(search_name) 3489 elif not self._searches().check_name_validity(search_name): 3490 raise ValueError(f"The {search_name} name is already taken!") 3491 3492 def _check_prediction_validity( 3493 self, prediction_name: str, force: bool = False 3494 ) -> None: 3495 """ 3496 Check whether the prediction name is valid 3497 """ 3498 3499 if prediction_name.startswith("_"): 3500 raise ValueError( 3501 "Names starting with an underscore are reserved by dlc2action and cannot be used!" 3502 ) 3503 elif "." in prediction_name: 3504 raise ValueError("Names containing '.' cannot be used!") 3505 if force: 3506 self.remove_prediction(prediction_name) 3507 elif not self._predictions().check_name_validity(prediction_name): 3508 raise ValueError(f"The {prediction_name} name is already taken!") 3509 3510 def _training_time(self, episode_name: str) -> int: 3511 """ 3512 Get the training time of an episode in seconds 3513 """ 3514 3515 return self._episode(episode_name).training_time() 3516 3517 def _mask_path(self) -> str: 3518 """ 3519 Get the path to the masks folder 3520 """ 3521 3522 return os.path.join(self.project_path, "results", "masks") 3523 3524 def _thresholds_path(self) -> str: 3525 """ 3526 Get the path to the thresholds meta file 3527 """ 3528 3529 return os.path.join(self.project_path, "meta", "thresholds.pickle") 3530 3531 def _episodes_path(self) -> str: 3532 """ 3533 Get the path to the episodes meta file 3534 """ 3535 3536 return os.path.join(self.project_path, "meta", "episodes.pickle") 3537 3538 def _saved_datasets_path(self) -> str: 3539 """ 3540 Get the path to the datasets meta file 3541 """ 3542 3543 return os.path.join(self.project_path, "meta", "saved_datasets.pickle") 3544 3545 def _predictions_path(self) -> str: 3546 """ 3547 Get the path to the predictions meta file 3548 """ 3549 3550 return os.path.join(self.project_path, "meta", "predictions.pickle") 3551 3552 def _dataset_store_path(self, name: str) -> str: 3553 """ 3554 Get the path to a specific pickled dataset 3555 """ 3556 3557 return os.path.join(self.project_path, "saved_datasets", f"{name}.pickle") 3558 3559 def _searches_path(self) -> str: 3560 """ 3561 Get the path to the hyperparameter search meta file 3562 """ 3563 3564 return os.path.join(self.project_path, "meta", "searches.pickle") 3565 3566 def _search_path(self, name: str) -> str: 3567 """ 3568 Get the default path to the graph folder for a specific hyperparameter search 3569 """ 3570 3571 return os.path.join(self.project_path, "results", "searches", name) 3572 3573 def _version_path(self) -> str: 3574 """ 3575 Get the path to the version file 3576 """ 3577 3578 return os.path.join(self.project_path, "meta", "version.txt") 
    def _default_split_file(self, split_info: Dict) -> Optional[str]:
        """
        Generate a path to a split file from split parameters
        """

        if split_info["partition_method"].startswith("time"):
            return None
        val_frac = split_info["val_frac"]
        test_frac = split_info["test_frac"]
        split_name = f'{split_info["partition_method"]}_val{val_frac * 100}%_test{test_frac * 100}%_len{split_info["len_segment"]}_overlap{split_info["overlap"]}'
        if not split_info["only_load_annotated"]:
            split_name += "_all"
        split_name += ".txt"
        return os.path.join(self.project_path, "results", "splits", split_name)

    def _split_info_from_filename(self, split_name: str) -> Dict:
        """
        Get split parameters from the default path to a split file
        """

        if split_name is None:
            return {}
        try:
            name = os.path.basename(split_name)[:-4]
            split = name.split("_")
            if len(split) == 6:
                only_load_annotated = False
            else:
                only_load_annotated = True
            len_segment = int(split[3][3:])
            overlap = int(split[4][7:])
            method, val, test = split[:3]
            val = float(val[3:-1]) / 100
            test = float(test[4:-1]) / 100
            return {
                "partition_method": method,
                "val_frac": val,
                "test_frac": test,
                "only_load_annotated": only_load_annotated,
                "len_segment": len_segment,
                "overlap": overlap,
            }
        except Exception:
            return {"partition_method": "file"}
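    # A hedged example of the filename round trip implemented above (the values
    # are made up for illustration):
    #
    #     split_info = {
    #         "partition_method": "random",
    #         "val_frac": 0.25,
    #         "test_frac": 0.25,
    #         "only_load_annotated": True,
    #         "len_segment": 128,
    #         "overlap": 40,
    #     }
    #     # _default_split_file(split_info) ends in
    #     # 'random_val25.0%_test25.0%_len128_overlap40.txt',
    #     # and _split_info_from_filename() recovers the dictionary above;
    #     # with only_load_annotated=False the name gains an '_all' suffix
    #     # (6 underscore-separated parts instead of 5).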
    def _fill(
        self,
        parameters: Dict,
        episode_name: str,
        load_experiment: str = None,
        load_epoch: int = None,
        load_strict: bool = True,
        only_load_model: bool = False,
        continuing: bool = False,
        enforce_split_parameters: bool = False,
    ) -> Dict:
        """
        Update the parameters from the config files with project specific information

        Fill in the constant file path parameters and generate a unique log file and a model folder.
        Fill in the split file if the same split has been run before in the project and change the partition
        method to "file".
        Fill in the saved data path if a dataset with the same data parameters already exists in the project.
        If load_experiment is not None, fill in the checkpoint path as well.
        The only_load_model training parameter is defined by the corresponding argument.
        If continuing is True (as it is for predictions), new files are not created and all information is
        loaded from load_experiment.
        The enforce_split_parameters parameter is used to resolve conflicts
        between the split file path and the split parameters when they arise.
        """

        pars = deepcopy(parameters)
        if episode_name == "_":
            self.remove_episode("_")
        log = os.path.join(self.project_path, "results", "logs", f"{episode_name}.txt")
        model_save_path = os.path.join(
            self.project_path, "results", "model", episode_name
        )
        if not continuing and (os.path.exists(log) or os.path.exists(model_save_path)):
            raise ValueError(
                f"The {episode_name} episode name is already in use! Set force=True to overwrite."
            )
        keys = ["val_frac", "test_frac", "partition_method"]
        if "len_segment" not in pars["general"] and "len_segment" in pars["data"]:
            pars["general"]["len_segment"] = pars["data"]["len_segment"]
        if "overlap" not in pars["general"] and "overlap" in pars["data"]:
            pars["general"]["overlap"] = pars["data"]["overlap"]
        if "len_segment" in pars["data"]:
            pars["data"].pop("len_segment")
        if "overlap" in pars["data"]:
            pars["data"].pop("overlap")
        split_info = {k: pars["training"][k] for k in keys}
        split_info["only_load_annotated"] = pars["general"]["only_load_annotated"]
        split_info["len_segment"] = pars["general"]["len_segment"]
        split_info["overlap"] = pars["general"]["overlap"]
        pars["training"]["log_file"] = log
        if not os.path.exists(model_save_path):
            os.mkdir(model_save_path)
        pars["training"]["model_save_path"] = model_save_path
        if load_experiment is not None:
            if load_experiment not in self._episodes().data.index:
                raise ValueError(f"The {load_experiment} episode does not exist!")
            old_episode = self._episode(load_experiment)
            old_file = old_episode.split_file()
            old_info = self._split_info_from_filename(old_file)
            if len(old_info) == 0:
                old_info = old_episode.split_info()
            if enforce_split_parameters:
                if split_info["partition_method"] != "file":
                    pars["training"]["split_path"] = self._default_split_file(
                        split_info
                    )
            else:
                equal = True
                if old_info["partition_method"] != split_info["partition_method"]:
                    equal = False
                if old_info["partition_method"] != "file":
                    if (
                        old_info["val_frac"] != split_info["val_frac"]
                        or old_info["test_frac"] != split_info["test_frac"]
                    ):
                        equal = False
                if not continuing and not equal:
                    warnings.warn(
                        f"The partitioning parameters in the loaded experiment ({old_info}) "
                        f"are not equal to the current partitioning parameters ({split_info}). "
                        f"The current parameters are replaced."
                    )
                pars["training"]["split_path"] = old_file
            pars["training"]["checkpoint_path"] = old_episode.model_file(load_epoch)
            pars["training"]["load_strict"] = load_strict
        else:
            pars["training"]["checkpoint_path"] = None
            if pars["training"]["partition_method"] == "file":
                if (
                    "split_path" not in pars["training"]
                    or pars["training"]["split_path"] is None
                ):
                    raise ValueError(
                        "The partition_method parameter is set to file but the "
                        "split_path parameter is not set!"
                    )
                elif not os.path.exists(pars["training"]["split_path"]):
                    raise ValueError(
                        f'The {pars["training"]["split_path"]} split file does not exist'
                    )
            else:
                pars["training"]["split_path"] = self._default_split_file(split_info)
        pars["training"]["only_load_model"] = only_load_model
        pars["data"]["saved_data_path"] = None
        pars["data"]["feature_save_path"] = None
        pars_data_copy = self._get_data_pars(pars)
        saved_data_name = self._saved_datasets().find_name(pars_data_copy)
        if saved_data_name is not None:
            pars["data"]["saved_data_path"] = self._dataset_store_path(saved_data_name)
            pars["data"]["feature_save_path"] = self._dataset_store_path(
                saved_data_name
            ).split(".")[0]
        else:
            dataset_path = self._dataset_store_path(episode_name)
            if os.path.exists(dataset_path):
                name, ext = dataset_path.split(".")
                i = 0
                while os.path.exists(f"{name}_{i}.{ext}"):
                    i += 1
                dataset_path = f"{name}_{i}.{ext}"
            pars["data"]["saved_data_path"] = dataset_path
            pars["data"]["feature_save_path"] = dataset_path.split(".")[0]
        split_split = pars["training"]["partition_method"].split(":")
        random = True
        for partition_method in options.partition_methods["fixed"]:
            method_split = partition_method.split(":")
            if len(split_split) != len(method_split):
                continue
            equal = True
            for x, y in zip(split_split, method_split):
                if y.startswith("{"):
                    continue
                if x != y:
                    equal = False
                    break
            if equal:
                random = False
                break
        if random and os.path.exists(pars["training"]["split_path"]):
            pars["training"]["partition_method"] = "file"
        pars["general"]["save_dataset"] = True
        return pars
    def _get_data_pars(self, pars: Dict) -> Dict:
        """
        Get a complete description of the data from a general parameters dictionary
        """

        pars_data_copy = deepcopy(pars["data"])
        for par in [
            "only_load_annotated",
            "exclusive",
            "feature_extraction",
            "ignored_clips",
            "len_segment",
            "overlap",
        ]:
            pars_data_copy[par] = pars["general"].get(par, None)
        pars_data_copy.update(pars["features"])
        return pars_data_copy

    def count_classes(
        self,
        load_episode: str = None,
        parameters_update: Dict = None,
        remove_saved_features: bool = False,
        bouts: bool = True,
    ) -> Dict:
        """
        Get a dictionary of class counts in different modes

        Parameters
        ----------
        load_episode : str, optional
            the episode settings to load
        parameters_update : dict, optional
            a dictionary of parameter updates (only for "data" and "general" categories)
        remove_saved_features : bool, default False
            if `True`, the dataset that is used for computation is then deleted
        bouts : bool, default True
            if `True`, segment counts are returned instead of frame counts

        Returns
        -------
        class_counts : dict
            a dictionary where first-level keys are "train", "val" and "test", second-level keys are
            class names and values are class counts (in frames or bouts)
        """

        if load_episode is None:
            task, parameters = self._make_task_training(
                episode_name="_", parameters_update=parameters_update, throwaway=True
            )
        else:
            task, parameters, _ = self._make_task_prediction(
                "_",
                load_episode=load_episode,
                parameters_update=parameters_update,
            )
        class_counts = task.count_classes(bouts=bouts)
        behaviors = task.behaviors_dict()
        class_counts = {
            kk: {behaviors.get(k, "unknown"): v for k, v in vv.items()}
            for kk, vv in class_counts.items()
        }
        if remove_saved_features:
            self._remove_stores(parameters)
        return class_counts
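    # A minimal usage sketch (hypothetical project and episode names):
    #
    #     project = Project("my_project")
    #     frame_counts = project.count_classes(bouts=False)
    #     bout_counts = project.count_classes(load_episode="first_try", bouts=True)
    #     print(frame_counts["train"])  # e.g. {"running": 5000, "grooming": 1200, ...}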
    def plot_class_distribution(
        self,
        parameters_update: Dict = None,
        frame_cutoff: int = 1,
        bout_cutoff: int = 1,
        print_full: bool = False,
        remove_saved_features: bool = False,
    ) -> None:
        """
        Make a class distribution plot

        You can either specify the parameters, choose an existing dataset or do neither (in that case a dataset
        is created or loaded for the computation with the default parameters).

        Parameters
        ----------
        parameters_update : dict, optional
            a dictionary of parameter updates (only for "data" and "general" categories)
        frame_cutoff : int, default 1
            classes with fewer frames are not included in the frame plot
        bout_cutoff : int, default 1
            classes with fewer bouts are not included in the bout plot
        print_full : bool, default False
            if `True`, the full class counts are also printed
        remove_saved_features : bool, default False
            if `True`, the dataset that is used for computation is then deleted
        """

        task, parameters = self._make_task_training(
            episode_name="_", parameters_update=parameters_update, throwaway=True
        )
        cutoff = {True: bout_cutoff, False: frame_cutoff}
        for bouts in [True, False]:
            class_counts = task.count_classes(bouts=bouts)
            if print_full:
                print("Bouts:" if bouts else "Frames:")
                for k, v in class_counts.items():
                    if sum(v.values()) != 0:
                        print(f"  {k}:")
                        values, keys = zip(
                            *[
                                x
                                for x in sorted(zip(v.values(), v.keys()), reverse=True)
                                if x[-1] != -100
                            ]
                        )
                        for kk, vv in zip(keys, values):
                            print(f"    {task.behaviors_dict()[kk]}: {vv}")
            class_counts = {
                kk: {k: v for k, v in vv.items() if v >= cutoff[bouts]}
                for kk, vv in class_counts.items()
            }
            for key, d in class_counts.items():
                if sum(d.values()) != 0:
                    values, keys = zip(
                        *[x for x in sorted(zip(d.values(), d.keys())) if x[-1] != -100]
                    )
                    keys = [task.behaviors_dict()[x] for x in keys]
                    plt.bar(keys, values)
                    plt.title(key)
                    plt.xticks(rotation=45, ha="right")
                    if bouts:
                        plt.ylabel("bouts")
                    else:
                        plt.ylabel("frames")
                    plt.tight_layout()
                    plt.show()
        if remove_saved_features:
            self._remove_stores(parameters)
    def _generate_mask(
        self,
        mask_name: str,
        perc_annotated: float = 0.1,
        parameters_update: Dict = None,
        remove_saved_features: bool = False,
    ) -> None:
        """
        Generate a mask for active learning simulation

        Parameters
        ----------
        mask_name : str
            the name of the mask
        """

        print(f"GENERATING {mask_name}")
        task, parameters = self._make_task_training(
            f"_{mask_name}", parameters_update=parameters_update, throwaway=True
        )
        val_intervals, val_ids = task.dataset("val").get_intervals()
        unannotated_intervals = task.dataset("train").get_unannotated_intervals()
        unannotated_intervals = task.dataset("val").get_unannotated_intervals(
            first_intervals=unannotated_intervals
        )
        ids = task.dataset("train").get_ids()
        mask = {video_id: {} for video_id in ids}
        total_all = 0
        total_masked = 0
        for video_id, clip_ids in ids.items():
            for clip_id in clip_ids:
                frames = np.ones(task.dataset("train").get_len(video_id, clip_id))
                if clip_id in val_intervals[video_id]:
                    for start, end in val_intervals[video_id][clip_id]:
                        frames[start:end] = 0
                if clip_id in unannotated_intervals[video_id]:
                    for start, end in unannotated_intervals[video_id][clip_id]:
                        frames[start:end] = 0
                annotated = np.where(frames)[0]
                total_all += len(annotated)
                masked = annotated[-int(len(annotated) * (1 - perc_annotated)) :]
                total_masked += len(masked)
                mask[video_id][clip_id] = self._get_intervals(masked)
        file = {
            "masked": mask,
            "val_intervals": val_intervals,
            "val_ids": val_ids,
            "unannotated": unannotated_intervals,
        }
        self._save_mask(file, mask_name)
        if remove_saved_features:
            self._remove_stores(parameters)
        print("\n")

    def _get_intervals(self, frame_indices: np.ndarray):
        """
        Get a list of intervals from a list of frame indices

        Example: `[0, 1, 2, 5, 6, 8] -> [[0, 3], [5, 7], [8, 9]]`.

        Parameters
        ----------
        frame_indices : np.ndarray
            a list of frame indices

        Returns
        -------
        intervals : list
            a list of interval boundaries
        """

        masked_intervals = []
        if len(frame_indices) > 0:
            breaks = np.where(np.diff(frame_indices) != 1)[0]
            start = frame_indices[0]
            for k in breaks:
                masked_intervals.append([start, frame_indices[k] + 1])
                start = frame_indices[k + 1]
            masked_intervals.append([start, frame_indices[-1] + 1])
        return masked_intervals

    def _update_mask_with_uncertainty(
        self,
        mask_name: str,
        episode_name: Union[str, None],
        classes: List,
        load_epoch: int = None,
        n_frames: int = 10000,
        method: str = "least_confidence",
        min_length: int = 30,
        augment_n: int = 0,
        parameters_update: Dict = None,
    ):
        """
        Update a mask with frame-wise uncertainty scores for active learning

        Parameters
        ----------
        mask_name : str
            the name of the mask
        episode_name : str
            the name of the episode to load
        classes : list
            a list of class names or indices; their uncertainty scores will be computed separately and stacked
        load_epoch : int, optional
            the epoch to load (by default the last one)
        n_frames : int, default 10000
            the number of frames to "annotate"
        method : {"least_confidence", "entropy"}
            the method used to calculate the scores from the probability predictions (`"least_confidence"`: `1 - p_i` if
            `p_i > 0.5` or `p_i` if `p_i < 0.5`; `"entropy"`: `- p_i * log(p_i) - (1 - p_i) * log(1 - p_i)`)
        min_length : int, default 30
            the minimum length (in frames) of the annotated intervals
        augment_n : int, default 0
            the number of augmentations to average over
        parameters_update : dict, optional
            the dictionary used to update the parameters from the config
        """

        print(f"UPDATING {mask_name}")
        task, parameters, _ = self._make_task_prediction(
            prediction_name=mask_name,
            load_episode=episode_name,
            parameters_update=parameters_update,
            load_epoch=load_epoch,
            mode="train",
        )
        score_tensors = task.generate_uncertainty_score(classes, augment_n, method)
        self._update_mask(task, mask_name, score_tensors, n_frames, min_length)
        print("\n")
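    # A standalone sketch of the two scoring rules described above (my own
    # illustration, not part of the dlc2action API):
    #
    #     import torch
    #
    #     def uncertainty_score(p: torch.Tensor, method: str) -> torch.Tensor:
    #         if method == "least_confidence":
    #             # 1 - p where p > 0.5, p otherwise
    #             return torch.where(p > 0.5, 1 - p, p)
    #         if method == "entropy":
    #             return -p * torch.log(p) - (1 - p) * torch.log(1 - p)
    #         raise ValueError(method)
    #
    # Both rules peak at p = 0.5, so the frames the model is least sure about
    # are suggested for annotation first.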
    def _update_mask_with_BALD(
        self,
        mask_name: str,
        episode_name: str,
        classes: List,
        load_epoch: int = None,
        augment_n: int = 0,
        n_frames: int = 10000,
        num_models: int = 10,
        kernel_size: int = 11,
        min_length: int = 30,
        parameters_update: Dict = None,
    ):
        """
        Update a mask with frame-wise Bayesian Active Learning by Disagreement scores for active learning

        Parameters
        ----------
        mask_name : str
            the name of the mask
        episode_name : str
            the name of the episode to load
        classes : list
            a list of class names or indices; their uncertainty scores will be computed separately and stacked
        load_epoch : int, optional
            the epoch to load (by default the last one)
        augment_n : int, default 0
            the number of augmentations to average over
        n_frames : int, default 10000
            the number of frames to "annotate"
        num_models : int, default 10
            the number of dropout masks to apply
        kernel_size : int, default 11
            the size of the Gaussian smoothing kernel
        min_length : int, default 30
            the minimum length (in frames) of the annotated intervals
        parameters_update : dict, optional
            the dictionary used to update the parameters from the config
        """

        print(f"UPDATING {mask_name}")
        task, parameters, mode = self._make_task_prediction(
            mask_name,
            load_episode=episode_name,
            parameters_update=parameters_update,
            load_epoch=load_epoch,
        )
        score_tensors = task.generate_bald_score(
            classes, augment_n, num_models, kernel_size
        )
        self._update_mask(task, mask_name, score_tensors, n_frames, min_length)
        print("\n")

    def _suggest_intervals(
        self,
        dataset: BehaviorDataset,
        score_tensors: Dict,
        n_frames: int,
        min_length: int,
    ) -> Dict:
        """
        Suggest intervals with the highest score of total length `n_frames`

        Parameters
        ----------
        dataset : BehaviorDataset
            the dataset
        score_tensors : dict
            a nested dictionary where first-level keys are video ids, second-level keys are clip ids and
            values are framewise score tensors
        n_frames : int
            the number of frames to "annotate"
        min_length : int
            the minimum length (in frames) of the suggested intervals

        Returns
        -------
        active_learning_intervals : dict
            active learning dictionary with suggested intervals
        """

        video_intervals, _ = dataset.get_intervals()
        taken = {
            video_id: defaultdict(lambda: {}) for video_id in video_intervals.keys()
        }
        annotated = dataset.get_annotated_intervals()
        for video_id in video_intervals:
            for clip_id in video_intervals[video_id]:
                taken[video_id][clip_id] = torch.zeros(
                    dataset.get_len(video_id, clip_id)
                )
                if video_id in annotated and clip_id in annotated[video_id]:
                    for start, end in annotated[video_id][clip_id]:
                        score_tensors[video_id][clip_id][:, start:end] = -10
                        taken[video_id][clip_id][int(start) : int(end)] = 1
        n_frames = (
            sum([(vv == 1).sum() for v in taken.values() for vv in v.values()])
            + n_frames
        )
        factor = 1
        threshold_start = float(
            torch.mean(
                torch.tensor(
                    [
                        torch.mean(
                            torch.tensor([torch.mean(y[y > 0]) for y in x.values()])
                        )
                        for x in score_tensors.values()
                    ]
                )
            )
        )
        while (
            sum([(vv == 1).sum() for v in taken.values() for vv in v.values()])
            < n_frames
        ):
            threshold = threshold_start * factor
            intervals = []
            interval_scores = []
            key1 = list(score_tensors.keys())[0]
            key2 = list(score_tensors[key1].keys())[0]
            num_scores = score_tensors[key1][key2].shape[0]
            for i in range(num_scores):
                v_dict = dataset.find_valleys(
                    predicted=score_tensors,
                    threshold=threshold,
                    min_frames=min_length,
                    main_class=i,
                    low=False,
                )
                for v_id, interval_list in v_dict.items():
                    intervals += [x + [v_id] for x in interval_list]
                    interval_scores += [
                        float(torch.mean(score_tensors[v_id][clip_id][i, start:end]))
                        for start, end, clip_id in interval_list
                    ]
            intervals = np.array(intervals)[np.argsort(interval_scores)[::-1]]
            i = 0
            while sum(
                [(vv == 1).sum() for v in taken.values() for vv in v.values()]
            ) < n_frames and i < len(intervals):
                start, end, clip_id, video_id = intervals[i]
                i += 1
                taken[video_id][clip_id][int(start) : int(end)] = 1
            factor *= 0.9
            if factor < 0.05:
                warnings.warn("Could not find enough frames!")
                break
        active_learning_intervals = {video_id: [] for video_id in video_intervals}
        for video_id in taken:
            for clip_id in taken[video_id]:
                if video_id in annotated and clip_id in annotated[video_id]:
                    for start, end in annotated[video_id][clip_id]:
                        taken[video_id][clip_id][int(start) : int(end)] = 0
                if (taken[video_id][clip_id] == 1).sum() == 0:
                    continue
                indices = np.where(taken[video_id][clip_id].numpy())[0]
                boundaries = self._get_intervals(indices)
                active_learning_intervals[video_id] += [
                    [start, end, clip_id] for start, end in boundaries
                ]
        return active_learning_intervals
    def _update_mask(
        self,
        task: TaskDispatcher,
        mask_name: str,
        score_tensors: Dict,
        n_frames: int,
        min_length: int,
    ) -> None:
        """
        Update the mask with the highest-scoring intervals of total length `n_frames`

        Parameters
        ----------
        mask_name : str
            the name of the mask
        score_tensors : dict
            a nested dictionary where first-level keys are video ids, second-level keys are clip ids and
            values are framewise score tensors
        n_frames : int
            the number of frames to "annotate"
        min_length : int
            the minimum length of the annotated intervals
        """

        mask = self._load_mask(mask_name)
        video_intervals, _ = task.dataset("train").get_intervals()
        masked = {
            video_id: defaultdict(lambda: {}) for video_id in video_intervals.keys()
        }
        total_masked = 0
        total_all = 0
        for video_id in video_intervals:
            for clip_id in video_intervals[video_id]:
                masked[video_id][clip_id] = torch.zeros(
                    task.dataset("train").get_len(video_id, clip_id)
                )
                if (
                    video_id in mask["unannotated"]
                    and clip_id in mask["unannotated"][video_id]
                ):
                    for start, end in mask["unannotated"][video_id][clip_id]:
                        score_tensors[video_id][clip_id][:, start:end] = -10
                        masked[video_id][clip_id][int(start) : int(end)] = 1
                if (
                    video_id in mask["val_intervals"]
                    and clip_id in mask["val_intervals"][video_id]
                ):
                    for start, end in mask["val_intervals"][video_id][clip_id]:
                        score_tensors[video_id][clip_id][:, start:end] = -10
                        masked[video_id][clip_id][int(start) : int(end)] = 1
                total_all += torch.sum(masked[video_id][clip_id] == 0)
                if video_id in mask["masked"] and clip_id in mask["masked"][video_id]:
                    for start, end in mask["masked"][video_id][clip_id]:
                        masked[video_id][clip_id][int(start) : int(end)] = 1
                        total_masked += end - start
        old_n_frames = sum(
            [(vv == 0).sum() for v in masked.values() for vv in v.values()]
        )
        n_frames = old_n_frames + n_frames
        factor = 1
        while (
            sum([(vv == 0).sum() for v in masked.values() for vv in v.values()])
            < n_frames
        ):
            threshold = float(
                torch.mean(
                    torch.tensor(
                        [
                            torch.mean(
                                torch.tensor([torch.mean(y[y > 0]) for y in x.values()])
                            )
                            for x in score_tensors.values()
                        ]
                    )
                )
            )
            threshold = threshold * factor
            intervals = []
            interval_scores = []
            key1 = list(score_tensors.keys())[0]
            key2 = list(score_tensors[key1].keys())[0]
            num_scores = score_tensors[key1][key2].shape[0]
            for i in range(num_scores):
                v_dict = task.dataset("train").find_valleys(
                    predicted=score_tensors,
                    threshold=threshold,
                    min_frames=min_length,
                    main_class=i,
                    low=False,
                )
                for v_id, interval_list in v_dict.items():
                    intervals += [x + [v_id] for x in interval_list]
                    interval_scores += [
                        float(torch.mean(score_tensors[v_id][clip_id][i, start:end]))
                        for start, end, clip_id in interval_list
                    ]
            intervals = np.array(intervals)[np.argsort(interval_scores)[::-1]]
            i = 0
            while sum(
                [(vv == 0).sum() for v in masked.values() for vv in v.values()]
            ) < n_frames and i < len(intervals):
                start, end, clip_id, video_id = intervals[i]
                i += 1
                masked[video_id][clip_id][int(start) : int(end)] = 0
            factor *= 0.9
            if factor < 0.05:
                warnings.warn("Could not find enough frames!")
                break
        mask["masked"] = {video_id: {} for video_id in video_intervals}
        total_masked_new = 0
        for video_id in masked:
            for clip_id in masked[video_id]:
                if (
                    video_id in mask["unannotated"]
                    and clip_id in mask["unannotated"][video_id]
                ):
                    for start, end in mask["unannotated"][video_id][clip_id]:
                        masked[video_id][clip_id][int(start) : int(end)] = 0
                if (
                    video_id in mask["val_intervals"]
                    and clip_id in mask["val_intervals"][video_id]
                ):
                    for start, end in mask["val_intervals"][video_id][clip_id]:
                        masked[video_id][clip_id][int(start) : int(end)] = 0
                indices = np.where(masked[video_id][clip_id].numpy())[0]
                mask["masked"][video_id][clip_id] = self._get_intervals(indices)
        for video_id in mask["masked"]:
            for clip_id in mask["masked"][video_id]:
                for start, end in mask["masked"][video_id][clip_id]:
                    total_masked_new += end - start
        self._save_mask(mask, mask_name)
        with open(
            os.path.join(self.project_path, "results", f"{mask_name}.txt"), "a"
        ) as f:
            f.write(f"from {total_masked} to {total_masked_new} / {total_all}" + "\n")
        print(f"Unmasked from {total_masked} to {total_masked_new} / {total_all}")
    def plot_confusion_matrix(
        self,
        episode_name: str,
        load_epoch: int = None,
        parameters_update: Dict = None,
        type: str = "recall",
        mode: str = "val",
        remove_saved_features: bool = False,
    ) -> Tuple[ndarray, Iterable]:
        """
        Make a confusion matrix plot and return the data

        If the annotation is non-exclusive, only false positive labels are considered.

        Parameters
        ----------
        episode_name : str
            the name of the episode to load
        load_epoch : int, optional
            the index of the epoch to load (by default the last one is loaded)
        parameters_update : dict, optional
            a dictionary of parameter updates (only for "data" and "general" categories)
        mode : {'val', 'all', 'test', 'train'}
            the subset of the data to make the prediction for (forced to 'all' if data_path is not None)
        type : {"recall", "precision"}
            for datasets with non-exclusive annotation, if `type` is `"recall"`, only false positives are taken
            into account, and if `type` is `"precision"`, only false negatives
        remove_saved_features : bool, default False
            if `True`, the dataset that is used for computation is then deleted

        Returns
        -------
        confusion_matrix : np.ndarray
            a confusion matrix of shape `(#classes, #classes)` where `A[i, j] = F_ij / N_i`, `F_ij` is the
            number of frames that have the i-th label in the ground truth and a false positive j-th label in
            the prediction, and `N_i` is the number of frames that have the i-th label in the ground truth
        classes : list
            a list of labels
        """
        task, parameters, mode = self._make_task_prediction(
            "_",
            load_episode=episode_name,
            load_epoch=load_epoch,
            parameters_update=parameters_update,
            mode=mode,
        )
        dataset = task.dataset(mode)
        prediction = task.predict(dataset, raw_output=True)
        confusion_matrix, classes, type = dataset.get_confusion_matrix(prediction, type)
        if remove_saved_features:
            self._remove_stores(parameters)
        fig, ax = plt.subplots(figsize=(len(classes), len(classes)))
        ax.imshow(confusion_matrix)
        # Show all ticks and label them with the respective list entries
        ax.set_xticks(np.arange(len(classes)))
        ax.set_xticklabels(classes)
        ax.set_yticks(np.arange(len(classes)))
        ax.set_yticklabels(classes)
        # Rotate the tick labels and set their alignment
        plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
        # Loop over data dimensions and create text annotations
        for i in range(len(classes)):
            for j in range(len(classes)):
                ax.text(
                    j,
                    i,
                    np.round(confusion_matrix[i, j], 2),
                    ha="center",
                    va="center",
                    color="w",
                )
        if type is not None:
            ax.set_title(f"{type} {episode_name}")
        else:
            ax.set_title(episode_name)
        fig.tight_layout()
        plt.show()
        return confusion_matrix, classes
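    # A minimal usage sketch (hypothetical episode name):
    #
    #     matrix, classes = project.plot_confusion_matrix(
    #         "first_try", type="recall", mode="val"
    #     )
    #     # matrix[i, j] is the fraction of ground-truth classes[i] frames that
    #     # received a false positive classes[j] prediction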
    def plot_predictions(
        self,
        episode_name: str,
        load_epoch: int = None,
        parameters_update: Dict = None,
        add_legend: bool = True,
        ground_truth: bool = True,
        colormap: str = "viridis",
        hide_axes: bool = False,
        min_classes: int = 1,
        width: float = 10,
        whole_video: bool = False,
        transparent: bool = False,
        drop_classes: Set = None,
        search_classes: Set = None,
        num_plots: int = 1,
        remove_saved_features: bool = False,
        smooth_interval_prediction: int = 0,
        data_path: str = None,
        file_paths: Set = None,
        mode: str = "val",
        behavior_name: str = None,
    ) -> None:
        """
        Visualize random predictions

        Parameters
        ----------
        episode_name : str
            the name of the episode to load
        load_epoch : int, optional
            the epoch to load (by default the last)
        parameters_update : dict, optional
            parameter update dictionary
        add_legend : bool, default True
            if `True`, a legend will be added to the plot
        ground_truth : bool, default True
            if `True`, ground truth will be added to the plot
        colormap : str, default 'viridis'
            the `matplotlib` colormap to use
        hide_axes : bool, default False
            if `True`, the axes will be hidden on the plot
        min_classes : int, default 1
            the minimum number of classes in a displayed interval
        width : float, default 10
            the width of the plot
        whole_video : bool, default False
            if `True`, whole videos are plotted instead of segments
        transparent : bool, default False
            if `True`, the background on the plot is transparent
        drop_classes : set, optional
            a set of class names not to be displayed
        search_classes : set, optional
            if given, only intervals where at least one of the classes is in the ground truth will be shown
        num_plots : int, default 1
            the number of plots to make
        remove_saved_features : bool, default False
            if `True`, the dataset will be deleted after computation
        smooth_interval_prediction : int, default 0
            if >0, predicted intervals shorter than this number of frames are removed (filled with the
            prediction for the previous frame)
        data_path : str, optional
            the data path to run the prediction for
        mode : {'all', 'test', 'val', 'train'}
            the subset of the data to make the prediction for (forced to 'all' if data_path is not None)
        file_paths : set, optional
            a set of string file paths (data with all prefixes + feature files, in any order) to run the
            prediction for
        behavior_name : str, optional
            for non-exclusive classification datasets, choose which behavior to visualize (by default the
            first in the list)
        """

        other_path = os.path.join(self.project_path, "results", "other")
        task, parameters, mode = self._make_task_prediction(
            "_",
            load_episode=episode_name,
            parameters_update=parameters_update,
            load_epoch=load_epoch,
            data_path=data_path,
            file_paths=file_paths,
            mode=mode,
        )
        if not os.path.exists(other_path):
            os.mkdir(other_path)
        for i in range(num_plots):
            task.visualize_results(
                save_path=os.path.join(
                    other_path, f"{episode_name}_prediction_{i}.jpg"
                ),
                add_legend=add_legend,
                ground_truth=ground_truth,
                colormap=colormap,
                hide_axes=hide_axes,
                min_classes=min_classes,
                whole_video=whole_video,
                transparent=transparent,
                dataset=mode,
                drop_classes=drop_classes,
                search_classes=search_classes,
                width=width,
                smooth_interval_prediction=smooth_interval_prediction,
                behavior_name=behavior_name,
            )
        if remove_saved_features:
            self._remove_stores(parameters)
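    # A minimal usage sketch (hypothetical names; the images land in the
    # results/other folder inside the project):
    #
    #     project.plot_predictions(
    #         "first_try", num_plots=3, search_classes={"grooming"}, mode="val"
    #     )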
    def create_metadata_backup(self) -> None:
        """
        Create a copy of the meta files
        """

        meta_copy_path = os.path.join(self.project_path, "meta", "backup")
        meta_path = os.path.join(self.project_path, "meta")
        if os.path.exists(meta_copy_path):
            shutil.rmtree(meta_copy_path)
        os.mkdir(meta_copy_path)
        for file in os.listdir(meta_path):
            if file == "backup":
                continue
            shutil.copy(
                os.path.join(meta_path, file), os.path.join(meta_copy_path, file)
            )

    def load_metadata_backup(self) -> None:
        """
        Load from a previously created metadata backup (in case of corruption)
        """

        meta_copy_path = os.path.join(self.project_path, "meta", "backup")
        meta_path = os.path.join(self.project_path, "meta")
        for file in os.listdir(meta_copy_path):
            shutil.copy(
                os.path.join(meta_copy_path, file), os.path.join(meta_path, file)
            )

    def get_behavior_dictionary(self, episode_name: str) -> Dict:
        """
        Get the behavior dictionary for an episode

        Parameters
        ----------
        episode_name : str
            the name of the episode

        Returns
        -------
        behaviors_dictionary : dict
            a dictionary where keys are label indices and values are label names
        """

        run = self._episodes().get_runs(episode_name)[0]
        return self._episode(run).get_behaviors_dict()
    def import_episodes(
        self,
        episodes_directory: str,
        name_map: Dict = None,
        repeat_policy: str = "error",
    ) -> None:
        """
        Import episodes exported with `Project.export_episodes`

        Parameters
        ----------
        episodes_directory : str
            the path to the exported episodes directory
        name_map : dict
            a name change dictionary for the episodes: keys are old names, values are new names
        repeat_policy : {'error', 'skip', 'force'}
            what to do when an imported episode name is already taken: raise an error, skip the episode
            or overwrite the existing one
        """

        if name_map is None:
            name_map = {}
        episodes = pd.read_pickle(os.path.join(episodes_directory, "episodes.pickle"))
        to_remove = []
        import_string = "Imported episodes: "
        for episode_name in episodes.index:
            if episode_name in name_map:
                import_string += f"{episode_name} "
                episode_name = name_map[episode_name]
                import_string += f"({episode_name}), "
            else:
                import_string += f"{episode_name}, "
            try:
                self._check_episode_validity(episode_name, allow_doublecolon=True)
            except ValueError as e:
                # match by substring: the episode error message continues with
                # "Set force=True to overwrite."
                if "is already taken" in str(e):
                    if repeat_policy == "skip":
                        to_remove.append(episode_name)
                    elif repeat_policy == "force":
                        self.remove_episode(episode_name)
                    elif repeat_policy == "error":
                        raise ValueError(
                            f"The {episode_name} episode name is already taken; please use the "
                            f"name_map parameter to rename it"
                        )
                    else:
                        raise ValueError(
                            f"The {repeat_policy} repeat policy is not recognized; please choose "
                            f"from ['skip', 'force' and 'error']"
                        )
        episodes = episodes.drop(index=to_remove)
        self._episodes().update(
            episodes,
            name_map=name_map,
            force=(repeat_policy == "force"),
            data_path=self.data_path,
            annotation_path=self.annotation_path,
        )
        for episode_name in episodes.index:
            if episode_name in name_map:
                new_episode_name = name_map[episode_name]
            else:
                new_episode_name = episode_name
            model_dir = os.path.join(
                self.project_path, "results", "model", new_episode_name
            )
            old_model_dir = os.path.join(episodes_directory, "model", episode_name)
            if os.path.exists(model_dir):
                shutil.rmtree(model_dir)
            os.mkdir(model_dir)
            for file in os.listdir(old_model_dir):
                shutil.copyfile(
                    os.path.join(old_model_dir, file), os.path.join(model_dir, file)
                )
            log_file = os.path.join(
                self.project_path, "results", "logs", f"{new_episode_name}.txt"
            )
            old_log_file = os.path.join(
                episodes_directory, "logs", f"{episode_name}.txt"
            )
            shutil.copyfile(old_log_file, log_file)
        print(import_string)
        print("\n")

    def export_episodes(
        self, episode_names: List, output_directory: str, name: str = None
    ) -> None:
        """
        Save selected episodes in a directory that can be imported into another project with
        `Project.import_episodes`

        Parameters
        ----------
        episode_names : list
            a list of string episode names
        output_directory : str
            the path to the directory where the episodes will be saved
        name : str, optional
            the name of the episodes directory (by default `exported_episodes`)
        """

        if name is None:
            name = "exported_episodes"
        if os.path.exists(
            os.path.join(output_directory, name + ".zip")
        ) or os.path.exists(os.path.join(output_directory, name)):
            i = 1
            while os.path.exists(
                os.path.join(output_directory, name + f"_{i}.zip")
            ) or os.path.exists(os.path.join(output_directory, name + f"_{i}")):
                i += 1
            name = name + f"_{i}"
        dest_dir = os.path.join(output_directory, name)
        os.mkdir(dest_dir)
        os.mkdir(os.path.join(dest_dir, "model"))
        os.mkdir(os.path.join(dest_dir, "logs"))
        runs = []
        for episode in episode_names:
            runs += self._episodes().get_runs(episode)
        for run in runs:
            shutil.copytree(
                os.path.join(self.project_path, "results", "model", run),
                os.path.join(dest_dir, "model", run),
            )
            shutil.copyfile(
                os.path.join(self.project_path, "results", "logs", f"{run}.txt"),
                os.path.join(dest_dir, "logs", f"{run}.txt"),
            )
        data = self._episodes().get_subset(runs)
        data.to_pickle(os.path.join(dest_dir, "episodes.pickle"))
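    # A round-trip sketch between two projects (hypothetical names and paths):
    #
    #     src = Project("lab_a")
    #     src.export_episodes(["first_try"], output_directory="/tmp/shared")
    #     dst = Project("lab_b")
    #     dst.import_episodes(
    #         "/tmp/shared/exported_episodes",
    #         name_map={"first_try": "lab_a_first_try"},
    #     )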
    def get_results_table(
        self,
        episode_names: List,
        metrics: List = None,
        include_std: bool = False,
        classes: List = None,
    ):
        """
        Generate a `pandas` dataframe with a summary of episode results

        Parameters
        ----------
        episode_names : list
            a list of names of episodes to include
        metrics : list, optional
            a list of metric names to include
        include_std : bool, default False
            if `True`, for episodes with multiple runs the mean and standard deviation will be displayed;
            otherwise only the mean
        classes : list, optional
            a list of names of classes to include (by default all are included)

        Returns
        -------
        results : pd.DataFrame
            a table with the results
        """

        run_names = []
        for episode in episode_names:
            run_names += self._episodes().get_runs(episode)
        episodes = self.list_episodes(run_names, print_results=False)
        metric_columns = [x for x in episodes.columns if x[0] == "results"]
        results_df = pd.DataFrame()
        if metrics is not None:
            metric_columns = [
                x for x in metric_columns if x[1].split("_")[0] in metrics
            ]
        for episode in episode_names:
            results = []
            metric_set = set()
            for run in self._episodes().get_runs(episode):
                beh_dict = self.get_behavior_dictionary(run)
                res_dict = defaultdict(lambda: {})
                for column in metric_columns:
                    if np.isnan(episodes.loc[run, column]):
                        continue
                    split = column[1].split("_")
                    if split[-1].isnumeric():
                        beh_ind = int(split[-1])
                        metric_name = "_".join(split[:-1])
                        beh = beh_dict[beh_ind]
                    else:
                        beh = "average"
                        metric_name = column[1]
                    res_dict[beh][metric_name] = episodes.loc[run, column]
                    metric_set.add(metric_name)
                if "average" not in res_dict:
                    res_dict["average"] = {}
                    for metric in metric_set:
                        if metric not in res_dict["average"]:
                            arr = [
                                res_dict[beh][metric]
                                for beh in res_dict
                                if metric in res_dict[beh]
                            ]
                            res_dict["average"][metric] = np.mean(arr)
                results.append(res_dict)
            episode_results = {}
            for metric in metric_set:
                for beh in results[0].keys():
                    if classes is not None and beh not in classes:
                        continue
                    arr = []
                    for res_dict in results:
                        if metric in res_dict[beh]:
                            arr.append(res_dict[beh][metric])
                    if len(arr) > 0:
                        if include_std:
                            episode_results[
                                (beh, f"{episode} {metric} mean")
                            ] = np.mean(arr)
                            episode_results[(beh, f"{episode} {metric} std")] = np.std(
                                arr
                            )
                        else:
                            episode_results[(beh, f"{episode} {metric}")] = np.mean(arr)
            for key, value in episode_results.items():
                results_df.loc[key[0], key[1]] = value
        print("RESULTS:")
        print(results_df)
        print("\n")
        return results_df
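    # A minimal usage sketch (hypothetical episode and metric names):
    #
    #     table = project.get_results_table(
    #         ["first_try", "second_try"], metrics=["f1"], include_std=True
    #     )
    #     # rows are class names (plus "average"); columns look like
    #     # "first_try f1 mean" / "first_try f1 std"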
    def episode_exists(self, episode_name: str) -> bool:
        """
        Check if an episode already exists

        Parameters
        ----------
        episode_name : str
            the episode name

        Returns
        -------
        exists : bool
            `True` if the episode exists
        """

        return self._episodes().check_name_validity(episode_name)

    def search_exists(self, search_name: str) -> bool:
        """
        Check if a search already exists

        Parameters
        ----------
        search_name : str
            the search name

        Returns
        -------
        exists : bool
            `True` if the search exists
        """

        return self._searches().check_name_validity(search_name)

    def prediction_exists(self, prediction_name: str) -> bool:
        """
        Check if a prediction already exists

        Parameters
        ----------
        prediction_name : str
            the prediction name

        Returns
        -------
        exists : bool
            `True` if the prediction exists
        """

        return self._predictions().check_name_validity(prediction_name)

    @staticmethod
    def project_name_available(projects_path: str, project_name: str) -> bool:
        """
        Check whether a project name is still available in a projects folder
        """

        if projects_path is None:
            projects_path = os.path.join(str(Path.home()), "DLC2Action")
        return not os.path.exists(os.path.join(projects_path, project_name))

    def _update_episode_metrics(self, episode_name: str, metrics: Dict):
        """
        Update meta data with evaluation results
        """

        self._episodes().update_episode_metrics(episode_name, metrics)

    def rename_episode(self, episode_name: str, new_episode_name: str):
        """
        Rename an episode and move its model and log files accordingly
        """

        shutil.move(
            os.path.join(self.project_path, "results", "model", episode_name),
            os.path.join(self.project_path, "results", "model", new_episode_name),
        )
        shutil.move(
            os.path.join(self.project_path, "results", "logs", f"{episode_name}.txt"),
            os.path.join(
                self.project_path, "results", "logs", f"{new_episode_name}.txt"
            ),
        )
        self._episodes().rename_episode(episode_name, new_episode_name)


class _Runner:
    """
    A helper class for running hyperparameter searches
    """

    def __init__(
        self,
        search_name,
        search_space: Dict,
        load_episode: str,
        load_epoch: int,
        metric: str,
        average: int,
        task: Union[TaskDispatcher, None],
        remove_saved_features: bool,
        project: Project,
    ):
        """
        Parameters
        ----------
        search_name : str
            the name of the search
        search_space : dict
            a dictionary representing the search space; of this general structure:
            {'group/param_name': ('float/int/float_log/int_log', start, end),
            'group/param_name': ('categorical', [choices])}, e.g.
            {'data/overlap': ('int', 5, 100), 'training/lr': ('float_log', 1e-4, 1e-2),
            'data/feature_extraction': ('categorical', ['kinematic', 'bones'])}
        load_episode : str
            the name of the episode to load the model from
        load_epoch : int
            the epoch to load the model from (if not provided, the last checkpoint is used)
        metric : str
            the metric to maximize or minimize
        average : int
            the number of epochs to average the metric over; if 0, the last value is taken
        task : TaskDispatcher, optional
            a `TaskDispatcher` instance to reuse between trials
        remove_saved_features : bool
            if `True`, the old datasets will be deleted when data parameters change
        project : Project
            the parent `Project` instance
        """

        self.search_space = search_space
        self.load_episode = load_episode
        self.load_epoch = load_epoch
        self.metric = metric
        self.average = average
        self.feature_save_path = None
        self.remove_saved_features = remove_saved_features
        self.save_stores = project._save_stores
        self.remove_datasets = project.remove_saved_features
        self.task = task
        self.search_name = search_name
        self.update = project._update
        self.remove_episode = project.remove_episode
        self.fill = project._fill

    def clean(self):
        """
        Remove datasets if needed
        """

        if self.remove_saved_features:
            self.remove_datasets([os.path.basename(self.feature_save_path)])

    def run(self, trial, parameters):
        """
        Make a trial run
        """

        params = deepcopy(parameters)
        param_update = defaultdict(
            lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: {})))
        )
        for full_name, space in self.search_space.items():
            group, param_name = full_name.split("/")[0], "/".join(
                full_name.split("/")[1:]
            )
            log = space[0][-3:] == "log"
            if space[0].startswith("int"):
                value = trial.suggest_int(full_name, space[1], space[2], log=log)
            elif space[0].startswith("float"):
                value = trial.suggest_float(full_name, space[1], space[2], log=log)
            elif space[0] == "categorical":
                value = trial.suggest_categorical(full_name, space[1])
            else:
                raise ValueError(
                    "The search space has to be formatted as either "
                    '("float"/"int"/"float_log"/"int_log", start, end) '
                    f'or ("categorical", [choices]); got {space} for {group}/{param_name}'
                )
("categorical", [choices]); got {space} for {group}/{param_name}' 4954 ) 4955 if len(param_name.split("/")) == 1: 4956 param_update[group][param_name] = value 4957 else: 4958 pars = param_name.split("/") 4959 pars = [int(x) if x.isnumeric() else x for x in pars] 4960 if len(pars) == 2: 4961 param_update[group][pars[0]][pars[1]] = value 4962 elif len(pars) == 3: 4963 param_update[group][pars[0]][pars[1]][pars[2]] = value 4964 elif len(pars) == 4: 4965 param_update[group][pars[0]][pars[1]][pars[2]][pars[3]] = value 4966 params = self.update(params, param_update) 4967 self.remove_episode(f"_{self.search_name}") 4968 params = self.fill( 4969 params, 4970 f"_{self.search_name}", 4971 self.load_episode, 4972 load_epoch=self.load_epoch, 4973 only_load_model=True, 4974 ) 4975 if self.feature_save_path != params["data"]["feature_save_path"]: 4976 if self.feature_save_path is not None: 4977 self.clean() 4978 self.feature_save_path = params["data"]["feature_save_path"] 4979 self.save_stores(params) 4980 if self.task is None: 4981 self.task = TaskDispatcher(deepcopy(params)) 4982 else: 4983 self.task.update_task(params) 4984 4985 _, metrics_log = self.task.train(trial, self.metric) 4986 metric_values = metrics_log["val"][self.metric] 4987 if self.average > 0: 4988 value = np.mean(sorted(metric_values)[-self.average :]) 4989 else: 4990 value = metric_values[-1] 4991 return value
51class Project: 52 """ 53 A class to create and maintain the project files + keep track of experiments 54 """ 55 56 def __init__( 57 self, 58 name: str, 59 data_type: str = None, 60 annotation_type: str = "none", 61 projects_path: str = None, 62 data_path: Union[str, List] = None, 63 annotation_path: Union[str, List] = None, 64 copy: bool = False, 65 ) -> None: 66 """ 67 Parameters 68 ---------- 69 name : str 70 name of the project 71 data_type : str, optional 72 data type (run Project.data_types() to see available options; has to be provided if the project is being 73 created) 74 annotation_type : str, default 'none' 75 annotation type (run Project.annotation_types() to see available options) 76 projects_path : str, optional 77 path to the projects folder (is filled with ~/DLC2Action by default) 78 data_path : str, optional 79 path to the folder containing input files for the project (has to be provided if the project is being 80 created) 81 annotation_path : str, optional 82 path to the folder containing annotation files for the project 83 copy : bool, default False 84 if True, the files from annotation_path and data_path will be copied to the projects folder; 85 otherwise they will be moved 86 """ 87 88 if projects_path is None: 89 projects_path = os.path.join(str(Path.home()), "DLC2Action") 90 if not os.path.exists(projects_path): 91 os.mkdir(projects_path) 92 self.project_path = os.path.join(projects_path, name) 93 self.name = name 94 self.data_type = data_type 95 self.annotation_type = annotation_type 96 self.data_path = data_path 97 self.annotation_path = annotation_path 98 if not os.path.exists(self.project_path): 99 if data_type is None: 100 raise ValueError( 101 "The data_type parameter is necessary when creating a new project!" 102 ) 103 self._initialize_project( 104 data_type, annotation_type, data_path, annotation_path, copy 105 ) 106 else: 107 self.annotation_type, self.data_type = self._read_types() 108 if data_type != self.data_type and data_type is not None: 109 raise ValueError( 110 f"The project has already been initialized with data_type={self.data_type}!" 111 ) 112 if annotation_type != self.annotation_type and annotation_type != "none": 113 raise ValueError( 114 f"The project has already been initialized with annotation_type={self.annotation_type}!" 115 ) 116 self.annotation_path, data_path = self._read_paths() 117 if self.data_path is None: 118 self.data_path = data_path 119 # if data_path != self.data_path and data_path is not None: 120 # raise ValueError( 121 # f"The project has already been initialized with data_path={self.data_path}!" 122 # ) 123 if annotation_path != self.annotation_path and annotation_path is not None: 124 raise ValueError( 125 f"The project has already been initialized with annotation_path={self.annotation_path}!" 
126 ) 127 self._update_configs() 128 129 def _aggregate_predictions( 130 self, 131 prediction_name: str, 132 episode_names: List, 133 load_epochs: List = None, 134 parameters_update: Dict = None, 135 data_path: str = None, 136 file_paths: Set = None, 137 mode: str = "all", 138 augment_n: int = 0, 139 evaluate: bool = False, 140 task: TaskDispatcher = None, 141 embedding: bool = False, 142 ) -> Tuple[TaskDispatcher, Dict, str, torch.Tensor]: 143 """ 144 Generate a prediction 145 """ 146 147 if load_epochs is None: 148 load_epochs = [None for _ in episode_names] 149 prediction = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: 0))) 150 cnt = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: 0))) 151 behs = set(self.get_behavior_dictionary(episode_names[0]).values()) 152 if not all( 153 [ 154 set(self.get_behavior_dictionary(x).values()) == behs 155 for x in episode_names 156 ] 157 ): 158 raise ValueError(f"The behavior sets are different in {episode_names}") 159 behaviors = set() 160 for i, episode_name in enumerate(episode_names): 161 task, parameters, data_mode, new_pred, _ = self._make_prediction( 162 prediction_name, 163 episode_names=[episode_name], 164 load_epochs=[load_epochs[i]], 165 parameters_update=parameters_update, 166 data_path=data_path, 167 file_paths=file_paths, 168 mode=mode, 169 augment_n=augment_n, 170 evaluate=evaluate, 171 task=task, 172 embedding=embedding, 173 ) 174 new_pred = task.dataset(data_mode).generate_full_length_prediction(new_pred) 175 beh_dict = task.behaviors_dict() 176 for video_id, video_values in new_pred.items(): 177 for clip_id, clip_prediction in video_values.items(): 178 for beh_i in range(clip_prediction.shape[0]): 179 prediction[video_id][clip_id][ 180 beh_dict[beh_i] 181 ] += clip_prediction[beh_i, :].unsqueeze(0) 182 cnt[video_id][clip_id][beh_dict[beh_i]] += 1 183 behaviors.add(beh_dict[beh_i]) 184 output = defaultdict(lambda: {}) 185 # behaviors = sorted(behaviors) 186 behavior_indices = sorted( 187 [x for x in task.behaviors_dict().keys() if x != -100] 188 ) 189 behaviors = [task.behaviors_dict()[key] for key in behavior_indices] 190 for video_id, video_values in prediction.items(): 191 for clip_id, clip_values in video_values.items(): 192 pred = torch.cat( 193 [ 194 clip_values[beh] / cnt[video_id][clip_id][beh] 195 for beh in behaviors 196 ], 197 0, 198 ) 199 output[video_id][clip_id] = pred 200 return task, parameters, data_mode, dict(output), None 201 202 def _make_prediction( 203 self, 204 prediction_name: str, 205 episode_names: List, 206 load_epochs: List = None, 207 parameters_update: Dict = None, 208 data_path: str = None, 209 file_paths: Set = None, 210 mode: str = "all", 211 augment_n: int = 0, 212 evaluate: bool = False, 213 task: TaskDispatcher = None, 214 embedding: bool = False, 215 ) -> Tuple[TaskDispatcher, Dict, str, torch.Tensor]: 216 """ 217 Generate a prediction 218 """ 219 220 names = [] 221 epochs = [] 222 if load_epochs is None: 223 load_epochs = [None for _ in episode_names] 224 if len(load_epochs) != len(episode_names): 225 raise ValueError( 226 f"The length of load_epochs and the length of episode_names should be the same!" 
227 ) 228 for i, episode_name in enumerate(episode_names): 229 names += self._episodes().get_runs(episode_name) 230 epochs.append(load_epochs[i]) 231 if len(names) == 0: 232 warnings.warn(f"None of the episodes {episode_names} exist!") 233 names = [None] 234 episodes = self._episodes() 235 lengths = [ 236 episodes.load_parameters(name)["general"]["len_segment"] for name in names 237 ] 238 overlaps = [ 239 episodes.load_parameters(name)["general"]["overlap"] for name in names 240 ] 241 if not all([x == lengths[0] for x in lengths]): 242 raise ValueError(f"Episodes {episode_names} have different segment lengths") 243 if not all([x == overlaps[0] for x in overlaps]): 244 raise ValueError(f"Episodes {episode_names} have different overlaps") 245 load_epochs = epochs 246 prediction = None 247 decision_thresholds = None 248 time_total = 0 249 behavior_dicts = [ 250 self.get_behavior_dictionary(episode_name) for episode_name in names 251 ] 252 if not all( 253 [ 254 set(d.values()) == set(behavior_dicts[0].values()) 255 for d in behavior_dicts[1:] 256 ] 257 ): 258 raise ValueError( 259 f"Episodes {episode_names} have different sets of behaviors!" 260 ) 261 behavior_indices = [x for x in behavior_dicts[0].keys() if x != -100] 262 behaviors = [behavior_dicts[0][i] for i in behavior_indices] 263 cnt = defaultdict(lambda: 0) 264 behavior_probs = defaultdict(lambda: 0) 265 for episode_name, load_epoch, behavior_dict in zip( 266 names, load_epochs, behavior_dicts 267 ): 268 print(f"episode {episode_name}") 269 task, parameters, data_mode = self._make_task_prediction( 270 prediction_name=prediction_name, 271 load_episode=episode_name, 272 parameters_update=parameters_update, 273 load_epoch=load_epoch, 274 data_path=data_path, 275 mode=mode, 276 file_paths=file_paths, 277 task=task, 278 decision_thresholds=decision_thresholds, 279 ) 280 behavior_indices_cur = [x for x in behavior_dict.keys() if x != -100] 281 behaviors_cur = [behavior_dict[i] for i in behavior_indices_cur] 282 # data_mode = "train" if mode == "all" else mode 283 time_start = time.time() 284 new_pred = task.predict( 285 data_mode, 286 raw_output=True, 287 apply_primary_function=True, 288 augment_n=augment_n, 289 embedding=embedding, 290 ) 291 for j, beh in enumerate(behaviors_cur): 292 cnt[beh] += 1 293 behavior_probs[beh] += new_pred[:, j, :].unsqueeze(1) 294 # indices = [ 295 # behaviors.index(behavior_dict[i]) for i in range(new_pred.shape[1]) 296 # ] 297 # new_pred = new_pred[:, indices, :] 298 time_end = time.time() 299 time_total += time_end - time_start 300 if evaluate: 301 _, metrics = task.evaluate_prediction(new_pred, data=data_mode) 302 if mode == "val": 303 self._update_episode_metrics(episode_name, metrics) 304 # if prediction is None: 305 # prediction = new_pred 306 # else: 307 # prediction += new_pred 308 print("\n") 309 prediction = torch.cat([behavior_probs[beh] / cnt[beh] for beh in behaviors], 1) 310 hours = int(time_total // 3600) 311 time_total -= hours * 3600 312 minutes = int(time_total // 60) 313 time_total -= minutes * 60 314 seconds = int(time_total) 315 inference_time = f"{hours}:{minutes:02}:{seconds:02}" 316 # prediction /= len(names) 317 return task, parameters, data_mode, prediction, inference_time 318 319 def _make_task_prediction( 320 self, 321 prediction_name: str, 322 load_episode: str = None, 323 parameters_update: Dict = None, 324 load_epoch: int = None, 325 data_path: str = None, 326 mode: str = "val", 327 file_paths: Set = None, 328 decision_thresholds: List = None, 329 task: TaskDispatcher = None, 330 
) -> Tuple[TaskDispatcher, Dict, str]: 331 """ 332 Make a `TaskDispatcher` object that will be used to generate a prediction 333 """ 334 335 if parameters_update is None: 336 parameters_update = {} 337 parameters_update_second = {} 338 if mode == "all" or data_path is not None or file_paths is not None: 339 parameters_update_second["training"] = { 340 "val_frac": 0, 341 "test_frac": 0, 342 "partition_method": "random", 343 "save_split": False, 344 "split_path": None, 345 } 346 mode = "train" 347 if decision_thresholds is not None: 348 if ( 349 len(decision_thresholds) 350 == self._episode(load_episode).get_num_classes() 351 ): 352 parameters_update_second["general"] = { 353 "threshold_value": decision_thresholds 354 } 355 else: 356 raise ValueError( 357 f"The length of the decision thresholds {decision_thresholds} " 358 f"must be equal to the length of the behaviors dictionary " 359 f"{self._episode(load_episode).get_behaviors_dict()}" 360 ) 361 data_param_update = {} 362 if data_path is not None: 363 data_param_update = {"data_path": data_path} 364 if file_paths is not None: 365 data_param_update = {"data_path": None, "file_paths": file_paths} 366 parameters_update = self._update(parameters_update, {"data": data_param_update}) 367 if data_path is not None or file_paths is not None: 368 general_update = { 369 "annotation_type": "none", 370 "only_load_annotated": False, 371 } 372 else: 373 general_update = {} 374 parameters_update = self._update(parameters_update, {"general": general_update}) 375 task, parameters = self._make_task( 376 episode_name=prediction_name, 377 load_episode=load_episode, 378 parameters_update=parameters_update, 379 parameters_update_second=parameters_update_second, 380 load_epoch=load_epoch, 381 purpose="prediction", 382 task=task, 383 behaviors=self.get_behavior_dictionary(load_episode), 384 ) 388 if mode is None: 389 if task.exists("test"): 390 mode = "test" 391 elif task.exists("val"): 392 mode = "val" 393 else: 394 mode = "train" 395 return task, parameters, mode 396 397 def _make_task_training( 398 self, 399 episode_name: str, 400 load_episode: str = None, 401 parameters_update: Dict = None, 402 load_epoch: int = None, 403 load_search: str = None, 404 load_parameters: list = None, 405 round_to_binary: list = None, 406 load_strict: bool = True, 407 continuing: bool = False, 408 task: TaskDispatcher = None, 409 mask_name: str = None, 410 throwaway: bool = False, 411 ) -> Tuple[TaskDispatcher, Dict]: 412 """ 413 Make a `TaskDispatcher` object that will be used to run a training episode 414 """ 415 416 if parameters_update is None: 417 parameters_update = {} 418 if continuing: 419 purpose = "continuing" 420 else: 421 purpose = "training" 422 if mask_name is not None: 423 mask_name = os.path.join(self._mask_path(), f"{mask_name}.pickle") 424 parameters_update_second = {"data": {"real_lens": mask_name}} 425 if throwaway: 426 parameters_update = self._update( 427 parameters_update, {"training": {"normalize": False, "device": "cpu"}} 428 ) 429 return self._make_task( 430 episode_name, 431 load_episode, 432 parameters_update, 433 parameters_update_second, 434 load_epoch, 435 load_search, 436 load_parameters, 437 round_to_binary, 438 purpose, 439 task, 440 load_strict=load_strict, 441 ) 442 443 def _make_parameters( 444 self, 445 episode_name: str, 446 load_episode: str = None, 447 parameters_update: Dict = None, 448
parameters_update_second: Dict = None, 449 load_epoch: int = None, 450 load_search: str = None, 451 load_parameters: list = None, 452 round_to_binary: list = None, 453 purpose: str = "train", 454 load_strict: bool = True, 455 ): 456 """ 457 Construct a parameters dictionary 458 """ 459 460 if parameters_update is None: 461 parameters_update = {} 462 pars_update = deepcopy(parameters_update) 463 if parameters_update_second is None: 464 parameters_update_second = {} 465 if purpose == "prediction" and "model" in pars_update.keys(): 466 raise ValueError("Cannot change model parameters after training!") 467 if purpose in ["continuing", "prediction"] and load_episode is not None: 468 read_parameters = self._read_parameters() 469 parameters = self._episodes().load_parameters(load_episode) 470 parameters["metrics"] = self._update( 471 read_parameters["metrics"], parameters["metrics"] 472 ) 473 parameters["ssl"] = self._update( 474 read_parameters["ssl"], parameters.get("ssl", {}) 475 ) 476 else: 477 parameters = self._read_parameters() 478 if "model" in pars_update: 479 model_params = pars_update.pop("model") 480 else: 481 model_params = None 482 if "features" in pars_update: 483 feat_params = pars_update.pop("features") 484 else: 485 feat_params = None 486 if "augmentations" in pars_update: 487 aug_params = pars_update.pop("augmentations") 488 else: 489 aug_params = None 490 parameters = self._update(parameters, pars_update) 491 if pars_update.get("general", {}).get("model_name") is not None: 492 model_name = parameters["general"]["model_name"] 493 parameters["model"] = self._open_yaml( 494 os.path.join(self.project_path, "config", "model", f"{model_name}.yaml") 495 ) 496 if pars_update.get("general", {}).get("feature_extraction") is not None: 497 feat_name = parameters["general"]["feature_extraction"] 498 parameters["features"] = self._open_yaml( 499 os.path.join( 500 self.project_path, "config", "features", f"{feat_name}.yaml" 501 ) 502 ) 503 aug_name = options.extractor_to_transformer[ 504 parameters["general"]["feature_extraction"] 505 ] 506 parameters["augmentations"] = self._open_yaml( 507 os.path.join( 508 self.project_path, "config", "augmentations", f"{aug_name}.yaml" 509 ) 510 ) 511 if model_params is not None: 512 parameters["model"] = self._update(parameters["model"], model_params) 513 if feat_params is not None: 514 parameters["features"] = self._update(parameters["features"], feat_params) 515 if aug_params is not None: 516 parameters["augmentations"] = self._update( 517 parameters["augmentations"], aug_params 518 ) 519 if load_search is not None: 520 parameters = self._update_with_search( 521 parameters, load_search, load_parameters, round_to_binary 522 ) 523 parameters = self._fill( 524 parameters, 525 episode_name, 526 load_episode, 527 load_epoch=load_epoch, 528 load_strict=load_strict, 529 only_load_model=(purpose != "continuing"), 530 continuing=(purpose in ["prediction", "continuing"]), 531 enforce_split_parameters=(purpose == "prediction"), 532 ) 533 parameters = self._update(parameters, parameters_update_second) 534 return parameters 535 536 def _make_task( 537 self, 538 episode_name: str, 539 load_episode: str = None, 540 parameters_update: Dict = None, 541 parameters_update_second: Dict = None, 542 load_epoch: int = None, 543 load_search: str = None, 544 load_parameters: list = None, 545 round_to_binary: list = None, 546 purpose: str = "train", 547 task: TaskDispatcher = None, 548 load_strict: bool = True, 549 behaviors: Dict = None, 550 ) -> Tuple[TaskDispatcher, 
Union[CommentedMap, dict]]: 551 """ 552 Make a `TaskDispatcher` object 553 554 The task parameters are read from the config files and then updated with the 555 parameters_update dictionary. The model can be either initialized from scratch or loaded from one of the 556 previous experiments. All parameters and results are saved in the meta files and can be accessed with the 557 list_episodes() function. The train/test/validation split is saved and loaded from a file whenever the 558 same split parameters are used. The pre-computed datasets are also saved and loaded whenever the same 559 data parameters are used. 560 561 Parameters 562 ---------- 563 episode_name : str 564 the name of the episode 565 load_episode : str, optional 566 the (previously run) episode name to load the model from 567 parameters_update : dict, optional 568 the dictionary used to update the parameters from the config 569 parameters_update_second : dict, optional 570 the dictionary used to update the parameters after the automatic fill-out 571 load_epoch : int, optional 572 the epoch to load (if load_episodes is not None); if not provided, the last epoch is used 573 load_search : str, optional 574 the hyperparameter search result to load 575 load_parameters : list, optional 576 a list of string names of the parameters to load from load_search (if not provided, all parameters 577 are loaded) 578 round_to_binary : list, optional 579 a list of string names of the loaded parameters that should be rounded to the nearest power of two 580 purpose : {"train", "continuing", "prediction"} 581 the purpose of the task object (`"train"` for training from scratch, `"continuing"` for continuing 582 the training of an interrupted episode, `"prediction"` for generating a prediction) 583 task : TaskDispatcher, optional 584 a pre-existing task; if provided, the method will update the task instead of creating a new one 585 (this might save time, mainly on dataset loading) 586 587 Returns 588 ------- 589 task : TaskDispatcher 590 the `TaskDispatcher` instance 591 parameters : dict 592 the parameters dictionary that describes the task 593 """ 594 595 parameters = self._make_parameters( 596 episode_name, 597 load_episode, 598 parameters_update, 599 parameters_update_second, 600 load_epoch, 601 load_search, 602 load_parameters, 603 round_to_binary, 604 purpose, 605 load_strict=load_strict, 606 ) 607 if parameters["data"].get("annotation_type", "none") == "none": 608 parameters = self._update( 609 parameters, {"data": {"behavior_dictionary": behaviors}} 610 ) 611 if task is None: 612 task = TaskDispatcher(parameters) 613 else: 614 task.update_task(parameters) 615 self._save_stores(parameters) 616 return task, parameters 617 618 def run_episode( 619 self, 620 episode_name: str, 621 load_episode: str = None, 622 parameters_update: Dict = None, 623 task: TaskDispatcher = None, 624 load_epoch: int = None, 625 load_search: str = None, 626 load_parameters: list = None, 627 round_to_binary: list = None, 628 load_strict: bool = True, 629 n_seeds: int = 1, 630 force: bool = False, 631 suppress_name_check: bool = False, 632 remove_saved_features: bool = False, 633 mask_name: str = None, 634 autostop_metric: str = None, 635 autostop_interval: int = 50, 636 autostop_threshold: float = 0.001, 637 loading_bar: bool = False, 638 trial: Tuple = None, 639 ) -> TaskDispatcher: 640 """ 641 Run an episode 642 643 The task parameters are read from the config files and then updated with the 644 parameters_update dictionary. 
The model can be either initialized from scratch or loaded from one of the 645 previous experiments. All parameters and results are saved in the meta files and can be accessed with the 646 list_episodes() function. The train/test/validation split is saved and loaded from a file whenever the 647 same split parameters are used. The pre-computed datasets are also saved and loaded whenever the same 648 data parameters are used. 649 650 You can use the autostop parameters to stop training early when the model is no longer improving. Training will be 651 stopped if the average value of `autostop_metric` over the last `autostop_interval` epochs is smaller than 652 the average over the previous `autostop_interval` epochs + `autostop_threshold`. For example, if the 653 current epoch is 120 and `autostop_interval` is 50, the averages over epochs 70-120 and 20-70 will be compared. 654 655 Parameters 656 ---------- 657 episode_name : str 658 the episode name 659 load_episode : str, optional 660 the (previously run) episode name to load the model from; if the episode has multiple runs, 661 the new episode will have the same number of runs, each starting with one of the pre-trained models 662 parameters_update : dict, optional 663 the dictionary used to update the parameters from the config files 664 task : TaskDispatcher, optional 665 a pre-existing `TaskDispatcher` object (if provided, the method will update it instead of creating 666 a new instance) 667 load_epoch : int, optional 668 the epoch to load (if load_episode is not None); if not provided, the last epoch is used 669 load_search : str, optional 670 the hyperparameter search result to load 671 load_parameters : list, optional 672 a list of string names of the parameters to load from load_search (if not provided, all parameters 673 are loaded) 674 round_to_binary : list, optional 675 a list of string names of the loaded parameters that should be rounded to the nearest power of two 676 load_strict : bool, default True 677 if `False`, matching weights will be loaded from `load_episode` and differences in parameter name lists and 678 weight shapes will be ignored; otherwise mismatches will prompt a `RuntimeError` 679 n_seeds : int, default 1 680 the number of runs to perform with different random seeds; if `n_seeds > 1`, the episodes will be named 681 `episode_name::seed_index`, e.g. `test_episode::0` and `test_episode::1` 682 force : bool, default False 683 if `True` and an episode with name `episode_name` already exists, it will be overwritten (use with caution!)
684 suppress_name_check : bool, default False 685 if `True`, episode names with a double colon are allowed (please don't use this option unless you understand 686 why they are usually forbidden) 687 remove_saved_features : bool, default False 688 if `True`, the dataset will be deleted after training 689 mask_name : str, optional 690 the name of the mask to apply 691 autostop_interval : int, default 50 692 the number of epochs to average the autostop metric over 693 autostop_threshold : float, default 0.001 694 the autostop difference threshold 695 autostop_metric : str, optional 696 the autostop metric (can be any one of the tracked metrics or `'loss'`) loading_bar : bool, default False if `True`, a progress bar is displayed during training trial : tuple, optional an `(optuna.Trial, metric_name)` tuple, used internally during hyperparameter searches 697 """ 698 699 if type(n_seeds) is not int or n_seeds < 1: 700 raise ValueError( 701 f"The n_seeds parameter has to be an integer larger than 0; got {n_seeds}" 702 ) 703 if n_seeds > 1 and mask_name is not None: 704 raise ValueError("Cannot apply a mask with n_seeds > 1") 705 self._check_episode_validity( 706 episode_name, allow_doublecolon=suppress_name_check, force=force 707 ) 708 load_runs = self._episodes().get_runs(load_episode) 709 if len(load_runs) > 1: 710 task = self.run_episodes( 711 episode_names=[ 712 f'{episode_name}::{run.split("::")[-1]}' for run in load_runs 713 ], 714 load_episodes=load_runs, 715 parameters_updates=[parameters_update for _ in load_runs], 716 load_epochs=[load_epoch for _ in load_runs], 717 load_searches=[load_search for _ in load_runs], 718 load_parameters=[load_parameters for _ in load_runs], 719 round_to_binary=[round_to_binary for _ in load_runs], 720 load_strict=[load_strict for _ in load_runs], 721 suppress_name_check=True, 722 force=force, 723 remove_saved_features=False, 724 ) 725 if remove_saved_features: 726 self._remove_stores( 727 { 728 "general": task.general_parameters, 729 "data": task.data_parameters, 730 "features": task.feature_parameters, 731 } 732 ) 733 if n_seeds > 1: 734 warnings.warn( 735 f"The n_seeds parameter is disregarded since load_episode={load_episode} has multiple runs" 736 ) 737 elif n_seeds > 1: 738 self.run_episodes( 739 episode_names=[f"{episode_name}::{i}" for i in range(n_seeds)], 740 load_episodes=[load_episode for _ in range(n_seeds)], 741 parameters_updates=[parameters_update for _ in range(n_seeds)], 742 load_epochs=[load_epoch for _ in range(n_seeds)], 743 load_searches=[load_search for _ in range(n_seeds)], 744 load_parameters=[load_parameters for _ in range(n_seeds)], 745 round_to_binary=[round_to_binary for _ in range(n_seeds)], 746 load_strict=[load_strict for _ in range(n_seeds)], 747 suppress_name_check=True, 748 force=force, 749 remove_saved_features=remove_saved_features, 750 ) 751 else: 752 print(f"TRAINING {episode_name}") 753 try: 754 task, parameters = self._make_task_training( 755 episode_name, 756 load_episode, 757 parameters_update, 758 load_epoch, 759 load_search, 760 load_parameters, 761 round_to_binary, 762 continuing=False, 763 task=task, 764 mask_name=mask_name, 765 load_strict=load_strict, 766 ) 767 self._save_episode( 768 episode_name, 769 parameters, 770 task.behaviors_dict(), 771 norm_stats=task.get_normalization_stats(), 772 ) 773 time_start = time.time() 774 if trial is not None: 775 trial, metric = trial 776 else: 777 trial, metric = None, None 778 logs = task.train( 779 autostop_metric=autostop_metric, 780 autostop_interval=autostop_interval, 781 autostop_threshold=autostop_threshold, 782 loading_bar=loading_bar, 783 trial=trial, 784 optimized_metric=metric, 785 ) 786 time_end = time.time() 787 time_total = time_end -
time_start 788 hours = int(time_total // 3600) 789 time_total -= hours * 3600 790 minutes = int(time_total // 60) 791 time_total -= minutes * 60 792 seconds = int(time_total) 793 training_time = f"{hours}:{minutes:02}:{seconds:02}" 794 self._update_episode_results(episode_name, logs, training_time) 795 if remove_saved_features: 796 self._remove_stores(parameters) 797 print("\n") 798 return task 799 800 except Exception as e: 801 if isinstance(e, optuna.exceptions.TrialPruned): 802 raise e 803 else: 806 raise RuntimeError(f"Episode {episode_name} could not run") from e 807 808 def run_episodes( 809 self, 810 episode_names: List, 811 load_episodes: List = None, 812 parameters_updates: List = None, 813 load_epochs: List = None, 814 load_searches: List = None, 815 load_parameters: List = None, 816 round_to_binary: List = None, 817 load_strict: List = None, 818 force: bool = False, 819 suppress_name_check: bool = False, 820 remove_saved_features: bool = False, 821 ) -> TaskDispatcher: 822 """ 823 Run multiple episodes in sequence (and re-use previously loaded information) 824 825 For each episode, the task parameters are read from the config files and then updated with the 826 corresponding parameters_updates dictionary. The model can be either initialized from scratch or loaded from one of the 827 previous experiments. All parameters and results are saved in the meta files and can be accessed with the 828 list_episodes() function. The train/test/validation split is saved and loaded from a file whenever the 829 same split parameters are used. The pre-computed datasets are also saved and loaded whenever the same 830 data parameters are used. 831 832 Parameters 833 ---------- 834 episode_names : list 835 a list of strings of episode names 836 load_episodes : list, optional 837 a list of strings of (previously run) episode names to load the model from; if the episode has multiple runs, 838 the new episode will have the same number of runs, each starting with one of the pre-trained models 839 parameters_updates : list, optional 840 a list of dictionaries used to update the parameters from the config 841 load_epochs : list, optional 842 a list of integers used to specify the epoch to load (if load_episodes is not None) 843 load_searches : list, optional 844 a list of strings of hyperparameter search results to load 845 load_parameters : list, optional 846 a list of lists of string names of the parameters to load from the searches 847 round_to_binary : list, optional 848 a list of string names of the loaded parameters that should be rounded to the nearest power of two 849 load_strict : list, optional 850 a list of boolean values specifying weight loading policy: if `False`, matching weights will be loaded from 851 the corresponding episode and differences in parameter name lists and 852 weight shapes will be ignored; otherwise mismatches will prompt a `RuntimeError` (by default `True` for 853 every episode) 854 force : bool, default False 855 if `True` and an episode name is already taken, it will be overwritten (use with caution!)
856 suppress_name_check : bool, default False 857 if `True`, episode names with a double colon are allowed (please don't use this option unless you understand 858 why they are usually forbidden) 859 remove_saved_features : bool, default False 860 if `True`, the dataset will be deleted after training 861 """ 862 863 task = None 864 if load_searches is None: 865 load_searches = [None for _ in episode_names] 866 if load_episodes is None: 867 load_episodes = [None for _ in episode_names] 868 if parameters_updates is None: 869 parameters_updates = [None for _ in episode_names] 870 if load_parameters is None: 871 load_parameters = [None for _ in episode_names] 872 if load_epochs is None: 873 load_epochs = [None for _ in episode_names] 874 if load_strict is None: 875 load_strict = [True for _ in episode_names] 876 for ( 877 parameters_update, 878 episode_name, 879 load_episode, 880 load_epoch, 881 load_search, 882 load_parameters_list, 883 load_strict_value, 884 ) in zip( 885 parameters_updates, 886 episode_names, 887 load_episodes, 888 load_epochs, 889 load_searches, 890 load_parameters, 891 load_strict, 892 ): 893 task = self.run_episode( 894 episode_name, 895 load_episode, 896 parameters_update, 897 task, 898 load_epoch, 899 load_search, 900 load_parameters_list, 901 round_to_binary, 902 load_strict_value, 903 suppress_name_check=suppress_name_check, 904 force=force, 905 remove_saved_features=remove_saved_features, 906 ) 907 return task 908 909 def continue_episode( 910 self, 911 episode_name: str, 912 num_epochs: int = None, 913 task: TaskDispatcher = None, 914 n_seeds: int = 1, 915 remove_saved_features: bool = False, 916 device: str = "cuda", 917 num_cpus: int = None, 918 ) -> TaskDispatcher: 919 """ 920 Load an older episode and continue running from the latest checkpoint 921 922 All parameters as well as the model and optimizer state dictionaries are loaded from the episode. 923 924 Parameters 925 ---------- 926 episode_name : str 927 the name of the episode to continue 928 num_epochs : int, optional 929 the new number of epochs 930 task : TaskDispatcher, optional 931 a pre-existing task; if provided, the method will update the task instead of creating a new one 932 (this might save time, mainly on dataset loading) 933 num_cpus : int, optional 934 the number of CPU cores to use in data processing 936 n_seeds : int, default 1 937 the number of runs to perform; if `n_seeds > 1`, the episodes will be named `episode_name::run_index`, e.g.
938 `test_episode::0` and `test_episode::1` 939 remove_saved_features : bool, default False 940 if `True`, pre-computed features will be deleted after the run 941 device : str, default "cuda" 942 the torch device to use 943 """ 944 945 runs = self._episodes().get_runs(episode_name) 946 for run in runs: 947 print(f"TRAINING {run}") 948 if num_epochs is None and not self._episode(run).unfinished(): 949 continue 950 parameters_update = { 951 "training": { 952 "num_epochs": num_epochs, 953 "device": device, 954 }, 955 "general": {"num_cpus": num_cpus}, 956 } 957 task, parameters = self._make_task_training( 958 run, 959 load_episode=run, 960 parameters_update=parameters_update, 961 continuing=True, 962 task=task, 963 ) 964 time_start = time.time() 965 logs = task.train() 966 time_end = time.time() 967 old_time = self._training_time(run) 968 if not np.isnan(old_time): 969 time_end += old_time 970 time_total = time_end - time_start 971 hours = int(time_total // 3600) 972 time_total -= hours * 3600 973 minutes = int(time_total // 60) 974 time_total -= minutes * 60 975 seconds = int(time_total) 976 training_time = f"{hours}:{minutes:02}:{seconds:02}" 977 else: 978 training_time = np.nan 979 self._save_episode( 980 run, 981 parameters, 982 task.behaviors_dict(), 983 suppress_validation=True, 984 training_time=training_time, 985 norm_stats=task.get_normalization_stats(), 986 ) 987 self._update_episode_results(run, logs) 988 print("\n") 989 if len(runs) < n_seeds: 990 for i in range(len(runs), n_seeds): 991 self.run_episode( 992 f"{episode_name}::{i}", 993 parameters_update=self._episodes().load_parameters(runs[0]), 994 task=task, 995 suppress_name_check=True, 996 ) 997 if remove_saved_features: 998 self._remove_stores(parameters) 999 return task 1000 1001 def run_default_hyperparameter_search( 1002 self, 1003 search_name: str, 1004 model_name: str = None, 1005 metric: str = "f1", 1006 best_n: int = 3, 1007 direction: str = "maximize", 1008 load_episode: str = None, 1009 load_epoch: int = None, 1010 load_strict: bool = True, 1011 prune: bool = True, 1012 force: bool = False, 1013 remove_saved_features: bool = False, 1014 overlap: float = 0, 1015 num_epochs: int = 50, 1016 test_frac: float = 0, 1017 n_trials=150, 1018 device: str = None, 1019 ): 1020 """ 1021 Run an optuna hyperparameter search with default parameters for a model 1022 1023 For the vast majority of cases, optimizing the default parameters should be enough. 1024 Check out `dlc2action.options.model_hyperparameters` for the lists of parameters. 1025 There are also options to set overlap, test fraction and number of epochs parameters for the search without 1026 modifying the project config files. However, if you want something more complex, look into 1027 `Project.run_hyperparameter_search`. 1028 1029 The task parameters are read from the config files and updated with the parameters_update dictionary. 1030 The model can be either initialized from scratch or loaded from a previously run episode. 1031 For each trial, the objective metric is averaged over a few best epochs. 
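A minimal usage sketch (the project, search and episode names here are made up, and "c2f_tcn" stands in for any model listed in `dlc2action.options.model_hyperparameters`):

project = Project("my_project")
best_params = project.run_default_hyperparameter_search(
    "tcn_search",
    model_name="c2f_tcn",
    metric="f1",
    n_trials=50,
)
# train an episode with the best parameters found by the search
project.run_episode("searched_episode", load_search="tcn_search")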
1032 1033 Parameters 1034 ---------- 1035 search_name : str 1036 the name of the search to store it in the meta files and load in run_episode 1037 model_name : str, optional 1038 the name of the model (by default loaded from the project settings, see `project.help('models')` for options) 1039 metric : str, default f1 1040 the metric to maximize/minimize (see direction); if the metric has an `"average"` parameter and it is set to 1041 `"none"` in the config files, it will be reset to `"macro"` for the search; see `project.help('metrics')` for options 1042 n_trials : int, default 150 1043 the number of optimization trials to run 1044 best_n : int, default 3 1045 the number of epochs to average the metric over; if 0, the last value is taken 1046 overlap : float, default 0 1047 the overlap between segments to use for the search (set without modifying the project config) num_epochs : int, default 50 the number of epochs to train for in each trial test_frac : float, default 0 the fraction of the dataset to use as the test subset load_strict : bool, default True if `False`, mismatches in weight names and shapes are ignored when loading from `load_episode` 1048 direction : {'maximize', 'minimize'} 1049 optimization direction 1050 load_episode : str, optional 1051 the name of the episode to load the model from 1052 load_epoch : int, optional 1053 the epoch to load the model from (if not provided, the last checkpoint is used) 1054 prune : bool, default True 1055 if `True`, experiments where the optimized metric is improving too slowly will be terminated 1056 (with optuna HyperBand pruner) 1057 force : bool, default False 1058 if `True`, existing searches with the same name will be overwritten 1059 remove_saved_features : bool, default False 1060 if `True`, pre-computed features will be deleted after each run (if the data parameters change) 1061 device : str, optional 1062 cuda:{i} or cpu, if not given it is read from the default parameters 1063 1064 Returns 1065 ------- 1066 dict 1067 a dictionary of best parameters 1068 """ 1069 1070 if model_name is None: 1071 model_name = self._read_parameters()["general"]["model_name"] 1072 if model_name not in options.model_hyperparameters: 1073 raise ValueError( 1074 f"There is no default search space for {model_name}!
Please choose from {options.model_hyperparameters.keys()} or try project.run_hyperparameter_search()" 1075 ) 1076 pars = { 1077 "general": { 1078 "overlap": overlap, 1079 "model_name": model_name, 1080 "metric_functions": {metric}, 1081 }, 1082 "training": {"num_epochs": num_epochs}, 1083 } 1084 if test_frac is not None: 1085 pars["training"]["test_frac"] = test_frac 1086 if not metric.split("_")[-1].isnumeric(): 1087 project_pars = self._read_parameters() 1088 if project_pars["metrics"][metric].get("average") == "none": 1089 pars["metrics"] = {metric: {"average": "macro"}} 1090 if device is not None: 1091 pars["training"]["device"] = device 1092 return self.run_hyperparameter_search( 1093 search_name=search_name, 1094 search_space=options.model_hyperparameters[model_name], 1095 metric=metric, 1096 n_trials=n_trials, 1097 best_n=best_n, 1098 parameters_update=pars, 1099 direction=direction, 1100 load_episode=load_episode, 1101 load_epoch=load_epoch, 1102 load_strict=load_strict, 1103 prune=prune, 1104 force=force, 1105 remove_saved_features=remove_saved_features, 1106 ) 1107 1108 def run_hyperparameter_search( 1109 self, 1110 search_name: str, 1111 search_space: Dict, 1112 metric: str = "f1", 1113 n_trials: int = 20, 1114 best_n: int = 1, 1115 parameters_update: Dict = None, 1116 direction: str = "maximize", 1117 load_episode: str = None, 1118 load_epoch: int = None, 1119 load_strict: bool = True, 1120 prune: bool = False, 1121 force: bool = False, 1122 remove_saved_features: bool = False, 1123 ) -> Dict: 1124 """ 1125 Run an optuna hyperparameter search 1126 1127 For a simpler function that fits most use cases, check out `Project.run_default_hyperparameter_search()`. 1128 1129 To use a default search space with this method, import `dlc2action.options.model_hyperparameters`. It is 1130 a dictionary where keys are model names and values are default search spaces. 1131 1132 The task parameters are read from the config files and updated with the parameters_update dictionary. 1133 The model can be either initialized from scratch or loaded from a previously run episode. 1134 For each trial, the objective metric is averaged over a few best epochs. 1135 1136 Parameters 1137 ---------- 1138 search_name : str 1139 the name of the search to store it in the meta files and load in run_episode 1140 search_space : dict 1141 a dictionary representing the search space; of this general structure: 1142 {'group/param_name': ('float/int/float_log/int_log', start, end), 1143 'group/param_name': ('categorical', [choices])}, e.g. 
1144 {'data/overlap': ('int', 5, 100), 'training/lr': ('float_log', 1e-4, 1e-2), 1145 'data/feature_extraction': ('categorical', ['kinematic', 'bones'])}; 1146 metric : str, default f1 1147 the metric to maximize/minimize (see direction) 1148 n_trials : int, default 20 1149 the number of optimization trials to run 1150 best_n : int, default 1 1151 the number of epochs to average the metric; if 0, the last value is taken 1152 parameters_update : dict, optional 1153 the parameters update dictionary 1154 direction : {'maximize', 'minimize'} 1155 optimization direction 1156 load_episode : str, optional 1157 the name of the episode to load the model from 1158 load_epoch : int, optional 1159 the epoch to load the model from (if not provided, the last checkpoint is used) 1160 prune : bool, default False 1161 if `True`, experiments where the optimized metric is improving too slowly will be terminated 1162 (with optuna HyperBand pruner) 1163 force : bool, default False 1164 if `True`, existing searches with the same name will be overwritten 1165 remove_saved_features : bool, default False 1166 if `True`, pre-computed features will be deleted after each run (if the data parameters change) 1167 1168 Returns 1169 ------- 1170 dict 1171 a dictionary of best parameters 1172 """ 1173 1174 self._check_search_validity(search_name, force=force) 1175 print(f"SEARCH {search_name}") 1176 self.remove_episode(f"_{search_name}") 1177 if parameters_update is None: 1178 parameters_update = {} 1179 parameters_update = self._update( 1180 parameters_update, {"general": {"metric_functions": {metric}}} 1181 ) 1182 parameters = self._make_parameters( 1183 f"_{search_name}", 1184 load_episode, 1185 parameters_update, 1186 parameters_update_second={"training": {"model_save_path": None}}, 1187 load_epoch=load_epoch, 1188 load_strict=load_strict, 1189 ) 1190 task = None 1191 1192 if prune: 1193 pruner = optuna.pruners.HyperbandPruner() 1194 else: 1195 pruner = optuna.pruners.NopPruner() 1196 study = optuna.create_study(direction=direction, pruner=pruner) 1197 runner = _Runner( 1198 search_space=search_space, 1199 load_episode=load_episode, 1200 load_epoch=load_epoch, 1201 metric=metric, 1202 average=best_n, 1203 task=task, 1204 remove_saved_features=remove_saved_features, 1205 project=self, 1206 search_name=search_name, 1207 ) 1208 study.optimize(lambda trial: runner.run(trial, parameters), n_trials=n_trials) 1209 search_path = self._search_path(search_name) 1210 os.mkdir(search_path) 1211 fig = optuna.visualization.plot_contour(study) 1212 plotly.offline.plot( 1213 fig, filename=os.path.join(search_path, f"{search_name}_contour.html") 1214 ) 1215 fig = optuna.visualization.plot_param_importances(study) 1216 plotly.offline.plot( 1217 fig, filename=os.path.join(search_path, f"{search_name}_importances.html") 1218 ) 1219 best_params = study.best_params 1220 best_value = study.best_value 1221 self._save_search( 1222 search_name, 1223 parameters, 1224 n_trials, 1225 best_params, 1226 best_value, 1227 metric, 1228 search_space, 1229 ) 1230 self.remove_episode(f"_{search_name}") 1231 runner.clean() 1232 print(f"best parameters: {best_params}") 1233 print("\n") 1234 return best_params 1235 1236 def run_prediction( 1237 self, 1238 prediction_name: str, 1239 episode_names: List, 1240 load_epochs: List = None, 1241 parameters_update: Dict = None, 1242 augment_n: int = 10, 1243 data_path: str = None, 1244 mode: str = "all", 1245 file_paths: Set = None, 1246 remove_saved_features: bool = False, 1247 submission: bool = False, 1248 
frame_number_map_file: str = None, 1249 force: bool = False, 1250 embedding: bool = False, 1251 ) -> None: 1252 """ 1253 Load models from previously run episodes to generate a prediction 1254 1255 The probabilities predicted by the models are averaged. 1256 Unless `submission` is `True`, the prediction results are saved as pickled dictionaries in the 1257 project_name/results/predictions/{prediction_name} folder, one file per video under the 1258 {video_id}_{prediction_name}_prediction.pickle name. Each file maps clip ids (like individual names) 1259 to prediction arrays and also contains 'min_frames', 'max_frames' and 'behaviors' keys. 1260 1261 Parameters 1262 ---------- 1263 prediction_name : str 1264 the name of the prediction 1265 episode_names : list 1266 a list of string episode names to load the models from 1267 load_epochs : list, optional 1268 a list of integer epoch indices to load the model from; if None, the last ones are used 1269 parameters_update : dict, optional 1270 a dictionary of parameter updates 1271 augment_n : int, default 10 1272 the number of augmentations to average over 1273 data_path : str, optional 1274 the data path to run the prediction for 1275 mode : {'all', 'test', 'val', 'train'} 1276 the subset of the data to make the prediction for (forced to 'all' if data_path is not None) 1277 file_paths : set, optional 1278 a set of string file paths (data with all prefixes + feature files, in any order) to run the prediction 1279 for 1280 remove_saved_features : bool, default False 1281 if `True`, pre-computed features will be deleted 1282 submission : bool, default False 1283 if `True`, a MABe-22 style submission file is generated 1284 frame_number_map_file : str, optional 1285 path to the frame number map file 1286 force : bool, default False 1287 if `True`, an existing prediction with this name will be overwritten embedding : bool, default False if `True`, the model embeddings are computed instead of the class probabilities 1288 """ 1289 1290 self._check_prediction_validity(prediction_name, force=force) 1291 print(f"PREDICTION {prediction_name}") 1292 if submission: 1293 task = ...
1294 # TODO: add submission option to _make_prediction 1295 predicted = task.generate_submission( 1296 frame_number_map_file=frame_number_map_file, 1297 dataset=mode, 1298 augment_n=augment_n, 1299 ) 1300 folder = os.path.join( 1301 self.project_path, 1302 "results", 1303 "predictions", 1304 f"{prediction_name}", 1305 ) 1306 filename = os.path.join(folder, f"{prediction_name}.npy") 1307 np.save(filename, predicted, allow_pickle=True) 1308 else: 1309 try: 1310 ( 1311 task, 1312 parameters, 1313 mode, 1314 prediction, 1315 inference_time, 1316 ) = self._make_prediction( 1317 prediction_name, 1318 episode_names, 1319 load_epochs, 1320 parameters_update, 1321 data_path, 1322 file_paths, 1323 mode, 1324 augment_n, 1325 evaluate=False, 1326 embedding=embedding, 1327 ) 1328 predicted = task.dataset(mode).generate_full_length_prediction( 1329 prediction 1330 ) 1331 except ValueError: 1332 ( 1333 task, 1334 parameters, 1335 mode, 1336 predicted, 1337 inference_time, 1338 ) = self._aggregate_predictions( 1339 prediction_name, 1340 episode_names, 1341 load_epochs, 1342 parameters_update, 1343 data_path, 1344 file_paths, 1345 mode, 1346 augment_n, 1347 evaluate=False, 1348 embedding=embedding, 1349 ) 1350 folder = self.prediction_path(prediction_name) 1351 os.mkdir(folder) 1352 for video_id, prediction in predicted.items(): 1353 with open( 1354 os.path.join( 1355 folder, video_id + f"_{prediction_name}_prediction.pickle" 1356 ), 1357 "wb", 1358 ) as f: 1359 prediction["min_frames"], prediction["max_frames"] = task.dataset( 1360 mode 1361 ).get_min_max_frames(video_id) 1362 behavior_indices = sorted( 1363 [key for key in task.behaviors_dict() if key != -100] 1364 ) 1365 prediction["behaviors"] = [ 1366 task.behaviors_dict()[key] for key in behavior_indices 1367 ] 1368 pickle.dump(prediction, f) 1369 if remove_saved_features: 1370 self._remove_stores(parameters) 1371 self._save_prediction( 1372 prediction_name, 1373 parameters, 1374 task.behaviors_dict(), 1375 embedding, 1376 inference_time, 1377 ) 1378 print("\n") 1379 1380 def evaluate_prediction( 1381 self, 1382 prediction_name: str, 1383 parameters_update: Dict = None, 1384 data_path: str = None, 1385 file_paths: Set = None, 1386 mode: str = None, 1387 remove_saved_features: bool = False, 1388 ) -> Tuple[float, dict]: 1389 1390 with open( 1391 os.path.join( 1392 self.project_path, "results", "predictions", f"{prediction_name}.pickle" 1393 ), 1394 "rb", 1395 ) as f: 1396 prediction = pickle.load(f) 1397 if parameters_update is None: 1398 parameters_update = {} 1399 parameters_update = self._update( 1400 self._predictions().load_parameters(prediction_name), parameters_update 1401 ) 1402 parameters_update.pop("model") 1403 task, parameters, mode = self._make_task_prediction( 1404 "_", 1405 load_episode=None, 1406 parameters_update=parameters_update, 1407 data_path=data_path, 1408 file_paths=file_paths, 1409 mode=mode, 1410 ) 1411 results = task.evaluate_prediction(prediction, data=mode) 1412 if remove_saved_features: 1413 self._remove_stores(parameters) 1414 print("\n") 1415 return results 1416 1417 def evaluate( 1418 self, 1419 episode_names: List, 1420 load_epochs: List = None, 1421 augment_n: int = 0, 1422 data_path: str = None, 1423 file_paths: Set = None, 1424 mode: str = None, 1425 parameters_update: Dict = None, 1426 multiple_episode_policy: str = "average", 1427 remove_saved_features: bool = False, 1428 skip_updating_meta: bool = True, 1429 ) -> Dict: 1430 """ 1431 Load one or several models from previously run episodes to make an evaluation 
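For example, a call might look like this (the episode name is hypothetical and assumed to be already trained):

project = Project("my_project")
metrics = project.evaluate(["test_episode"], mode="val")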
1432 1433 By default it will run on the test (or validation, if there is no test) subset of the project dataset. 1434 1435 Parameters 1436 ---------- 1437 episode_names : list 1438 a list of string episode names to load the models from 1439 load_epochs : list, optional 1440 a list of integer epoch indices to load the model from; if None, the last ones are used 1441 augment_n : int, default 0 1442 the number of augmentations to average over 1443 data_path : str, optional 1444 the data path to run the prediction for 1445 file_paths : set, optional 1446 a set of files to run the prediction for 1447 mode : {'test', 'val', 'train', 'all'} 1448 the subset of the data to make the prediction for (forced to 'all' if data_path is not None; 1449 by default 'test' if test subset is not empty and 'val' otherwise) 1450 parameters_update : dict, optional 1451 a dictionary with parameter updates (cannot change model parameters) 1452 multiple_episode_policy : {'average', 'statistics'} if 'average', the predictions of the episodes are averaged before evaluation; if 'statistics', each episode is evaluated separately and the mean and standard deviation of each metric are reported 1453 remove_saved_features : bool, default False if `True`, the dataset will be deleted skip_updating_meta : bool, default True if `False`, metrics computed on the validation subset are saved to the episode meta records 1454 1455 Returns 1456 ------- 1457 metric : dict 1458 a dictionary of average metric values (or, with the 'statistics' policy, of their means and standard deviations) 1459 """ 1460 1461 names = [] 1462 for episode_name in episode_names: 1463 names += self._episodes().get_runs(episode_name) 1464 if len(set(episode_names)) == 1: 1465 print(f"EVALUATION {episode_names[0]}") 1466 else: 1467 print(f"EVALUATION {episode_names}") 1468 if len(names) > 1: 1469 evaluate = True 1470 else: 1471 evaluate = False 1472 if multiple_episode_policy == "average": 1473 try: 1474 ( 1475 task, 1476 parameters, 1477 mode, 1478 prediction, 1479 inference_time, 1480 ) = self._make_prediction( 1481 "_", 1482 episode_names, 1483 load_epochs, 1484 parameters_update, 1485 mode=mode, 1486 data_path=data_path, 1487 file_paths=file_paths, 1488 augment_n=augment_n, 1489 evaluate=evaluate, 1490 ) 1491 except ValueError: 1492 ( 1493 task, 1494 parameters, 1495 mode, 1496 prediction, 1497 inference_time, 1498 ) = self._aggregate_predictions( 1499 "_", 1500 episode_names, 1501 load_epochs, 1502 parameters_update, 1503 mode=mode, 1504 data_path=data_path, 1505 file_paths=file_paths, 1506 augment_n=augment_n, 1507 evaluate=evaluate, 1508 ) 1509 print("AGGREGATED:") 1510 _, results = task.evaluate_prediction(prediction, data=mode) 1511 if len(names) == 1 and mode == "val" and not skip_updating_meta: 1512 self._update_episode_metrics(names[0], results) 1513 elif multiple_episode_policy == "statistics": 1514 values = defaultdict(lambda: []) 1515 task = None 1516 for name in names: 1517 ( 1518 task, 1519 parameters, 1520 mode, 1521 prediction, 1522 inference_time, 1523 ) = self._make_prediction( 1524 "_", 1525 [name], 1526 load_epochs, 1527 parameters_update, 1528 mode=mode, 1529 data_path=data_path, 1530 file_paths=file_paths, 1531 augment_n=augment_n, 1532 evaluate=evaluate, 1533 task=task, 1534 ) 1535 _, metrics = task.evaluate_prediction(prediction, data=mode) 1536 for metric_name, value in metrics.items(): 1537 values[metric_name].append(value) 1538 if mode == "val" and not skip_updating_meta: 1539 self._update_episode_metrics(name, metrics) 1540 results = defaultdict(lambda: {}) 1541 mean_string = "" 1542 std_string = "" 1543 for key, value_list in values.items(): 1544 results[key]["mean"] = np.mean(value_list) 1545 results[key]["std"] = np.std(value_list) 1546 mean_string += f"{key} {np.mean(value_list):.3f}, " 1547 std_string += f"{key} {np.std(value_list):.3f}, " 1548 print("MEAN:") 1549 print(mean_string) 1550 print("STD:") 1551 print(std_string) 1552 else: 1553 raise ValueError( 1554 f"The
{multiple_episode_policy} multiple episode policy is not recognized; please choose " 1555 f"from ['average', 'statistics']" 1556 ) 1557 if len(names) > 0 and remove_saved_features: 1558 self._remove_stores(parameters) 1559 print(f"Inference time: {inference_time}") 1560 print("\n") 1561 return results 1562 1563 def _generate_similarity_score( 1564 self, 1565 prediction_name: str, 1566 target_video_id: str, 1567 target_clip: str, 1568 target_start: int, 1569 target_end: int, 1570 ) -> Dict: 1571 with open( 1572 os.path.join( 1573 self.project_path, 1574 "results", 1575 "predictions", 1576 f"{prediction_name}.pickle", 1577 ), 1578 "rb", 1579 ) as f: 1580 prediction = pickle.load(f) 1581 target = prediction[target_video_id][target_clip][:, target_start:target_end] 1582 score_dict = defaultdict(lambda: {}) 1583 for video_id in prediction: 1584 for clip_id in prediction[video_id]: 1585 score_dict[video_id][clip_id] = torch.cdist( 1586 target.T, prediction[video_id][clip_id].T 1587 ).min(0)[0] # the smallest distance from each frame to the target interval 1588 return score_dict 1589 1590 def _suggest_intervals_from_dict(self, score_dict, min_length, n_intervals) -> Dict: 1591 interval_address = {} 1592 interval_value = {} 1593 s = 0 1594 n = 0 1595 for video_id, video_dict in score_dict.items(): 1596 for clip_id, value in video_dict.items(): 1597 s += value.mean() 1598 n += 1 1599 mean_value = s / n 1600 alpha = 1.75 1601 for it in range(10): 1602 id = 0 1603 interval_address = {} 1604 interval_value = {} 1605 for video_id, video_dict in score_dict.items(): 1606 for clip_id, value in video_dict.items(): 1607 res_indices_start, res_indices_end = apply_threshold( 1608 value, 1609 threshold=(2 - alpha * (0.9**it)) * mean_value, 1610 low=True, 1611 error_mask=None, 1612 min_frames=min_length, 1613 smooth_interval=0, 1614 ) 1615 for start, end in zip(res_indices_start, res_indices_end): 1616 interval_address[id] = [video_id, clip_id, start, end] 1617 interval_value[id] = score_dict[video_id][clip_id][ 1618 start:end 1619 ].mean() 1620 id += 1 1621 if len(interval_address) >= n_intervals: 1622 break 1623 if len(interval_address) < n_intervals: 1624 warnings.warn( 1625 f"Could not get {n_intervals} intervals from the data, saving the result with {len(interval_address)} intervals" 1626 ) 1627 sorted_intervals = sorted( 1628 interval_value.items(), key=lambda x: x[1], reverse=True 1629 ) 1630 output_intervals = [ 1631 interval_address[x[0]] 1632 for x in sorted_intervals[: min(len(sorted_intervals), n_intervals)] 1633 ] 1634 output = defaultdict(lambda: []) 1635 for video_id, clip_id, start, end in output_intervals: 1636 output[video_id].append([start, end, clip_id]) 1637 return output 1638 1639 def list_episodes( 1640 self, 1641 episode_names: List = None, 1642 value_filter: str = "", 1643 display_parameters: List = None, 1644 print_results: bool = True, 1645 ) -> pd.DataFrame: 1646 """ 1647 Get a filtered pandas dataframe with episode metadata 1648 1649 Parameters 1650 ---------- 1651 episode_names : list 1652 a list of strings of episode names 1653 value_filter : str 1654 a string of filters to apply; of this general structure: 1655 'group_name1/par_name1::(</>/<=/>=/=)value1,group_name2/par_name2::(</>/<=/>=/=)value2', e.g. 1656 'data/overlap::=50,results/recall::>0.5,data/feature_extraction::=kinematic,meta/training_time::>=00:00:10' 1657 display_parameters : list 1658 list of parameters to display (e.g.
['data/overlap', 'results/recall']) 1659 print_results : bool, default True 1660 if True, the result will be printed to standard output 1661 1662 Returns 1663 ------- 1664 pd.DataFrame 1665 the filtered dataframe 1666 """ 1667 1668 episodes = self._episodes().list_episodes( 1669 episode_names, value_filter, display_parameters 1670 ) 1671 if print_results: 1672 print("TRAINING EPISODES") 1673 print(episodes) 1674 print("\n") 1675 return episodes 1676 1677 def list_predictions( 1678 self, 1679 episode_names: List = None, 1680 value_filter: str = "", 1681 display_parameters: List = None, 1682 print_results: bool = True, 1683 ) -> pd.DataFrame: 1684 """ 1685 Get a filtered pandas dataframe with prediction metadata 1686 1687 Parameters 1688 ---------- 1689 episode_names : list 1690 a list of strings of episode names 1691 value_filter : str 1692 a string of filters to apply; of this general structure: 1693 'group_name1/par_name1:(<>=)value1,group_name2/par_name2:(<>=)value2', e.g. 1694 'data/overlap:=50,results/recall:>0.5,data/feature_extraction:=kinematic' 1695 display_parameters : list 1696 list of parameters to display (e.g. ['data/overlap', 'results/recall']) 1697 print_results : bool, default True 1698 if True, the result will be printed to standard output 1699 1700 Returns 1701 ------- 1702 pd.DataFrame 1703 the filtered dataframe 1704 """ 1705 1706 predictions = self._predictions().list_episodes( 1707 episode_names, value_filter, display_parameters 1708 ) 1709 if print_results: 1710 print("PREDICTIONS") 1711 print(predictions) 1712 print("\n") 1713 return predictions 1714 1715 def list_searches( 1716 self, 1717 search_names: List = None, 1718 value_filter: str = "", 1719 display_parameters: List = None, 1720 print_results: bool = True, 1721 ) -> pd.DataFrame: 1722 """ 1723 Get a filtered pandas dataframe with hyperparameter search metadata 1724 1725 Parameters 1726 ---------- 1727 search_names : list 1728 a list of strings of search names 1729 value_filter : str 1730 a string of filters to apply; of this general structure: 1731 'group_name1/par_name1:(<>=)value1,group_name2/par_name2:(<>=)value2', e.g. 1732 'data/overlap:=50,results/recall:>0.5,data/feature_extraction:=kinematic' 1733 display_parameters : list 1734 list of parameters to display (e.g. 
['data/overlap', 'results/recall']) 1735 print_results : bool, default True 1736 if True, the result will be printed to standard output 1737 1738 Returns 1739 ------- 1740 pd.DataFrame 1741 the filtered dataframe 1742 """ 1743 1744 searches = self._searches().list_episodes( 1745 search_names, value_filter, display_parameters 1746 ) 1747 if print_results: 1748 print("SEARCHES") 1749 print(searches) 1750 print("\n") 1751 return searches 1752 1753 def get_best_parameters( 1754 self, 1755 search_name: str, 1756 round_to_binary: List = None, 1757 ): """ Get the best parameters found by a hyperparameter search, with the best model name filled in Parameters ---------- search_name : str the name of the search round_to_binary : list, optional a list of string names of the parameters that should be rounded to the nearest power of two Returns ------- params : dict the best parameters dictionary """ 1758 params, model = self._searches().get_best_params( 1759 search_name, round_to_binary=round_to_binary 1760 ) 1761 params = self._update(params, {"general": {"model_name": model}}) 1762 return params 1763 1764 def list_best_parameters( 1765 self, search_name: str, print_results: bool = True 1766 ) -> Dict: 1767 """ 1768 Get the raw dictionary of best parameters found by a search 1769 1770 Parameters 1771 ---------- 1772 search_name : str 1773 the name of the search 1774 print_results : bool, default True 1775 if True, the result will be printed to standard output 1776 1777 Returns 1778 ------- 1779 best_params : dict 1780 a dictionary of the best parameters where the keys are in '{group}/{name}' format 1781 """ 1782 1783 params = self._searches().get_best_params_raw(search_name) 1784 if print_results: 1785 print(f"SEARCH RESULTS {search_name}") 1786 for k, v in params.items(): 1787 print(f"{k}: {v}") 1788 print("\n") 1789 return params 1790 1791 def plot_episodes( 1792 self, 1793 episode_names: List, 1794 metrics: List, 1795 modes: List = None, 1796 title: str = None, 1797 episode_labels: List = None, 1798 save_path: str = None, 1799 add_hlines: List = None, 1800 epoch_limits: List = None, 1801 colors: List = None, 1802 add_highpoint_hlines: bool = False, 1803 ) -> None: 1804 """ 1805 Plot episode training curves 1806 1807 Parameters 1808 ---------- 1809 episode_names : list 1810 a list of episode names to plot; to plot two episodes as one curve, combine them in a list 1811 (e.g.
['episode1', ['episode2', 'episode3']] to plot episode2 and episode3 as one experiment) 1812 metrics : list 1813 a list of metrics to plot 1814 modes : list, optional 1815 a list of modes to plot ('train' and/or 'val'; `['val']` by default) 1816 title : str, optional 1817 title for the plot 1818 episode_labels : list, optional 1819 a list of strings used to label the curves (has to be the same length as episode_names) 1820 save_path : str, optional 1821 the path to save the resulting plot 1822 add_hlines : list, optional 1823 a list of float values (or (value, label) tuples) to mark with horizontal lines epoch_limits : list, optional a list of two values to use as the epoch axis limits 1824 colors: list, optional 1825 a list of matplotlib colors 1826 add_highpoint_hlines : bool, default False 1827 if `True`, horizontal lines will be added at the highest value of each episode 1828 """ 1829 1830 if modes is None: 1831 modes = ["val"] 1832 if add_hlines is None: 1833 add_hlines = [] 1834 logs = [] 1835 epochs = [] 1836 labels = [] 1837 if episode_labels is not None: 1838 assert len(episode_labels) == len(episode_names) 1839 for name_i, name in enumerate(episode_names): 1840 log_params = product(metrics, modes) 1841 for metric, mode in log_params: 1842 if episode_labels is not None: 1843 label = episode_labels[name_i] 1844 else: 1845 label = deepcopy(name) 1846 if len(modes) != 1: 1847 label += f"_{mode}" 1848 if len(metrics) != 1: 1849 label += f"_{metric}" 1850 labels.append(label) 1851 if isinstance(name, Iterable) and not isinstance(name, str): 1852 epoch_list = defaultdict(lambda: []) 1853 multi_logs = defaultdict(lambda: []) 1854 for i, n in enumerate(name): 1855 runs = self._episodes().get_runs(n) 1856 if len(runs) > 1: 1857 for run in runs: 1858 index = run.split("::")[-1] 1859 if multi_logs[index] == []: 1860 if multi_logs["null"] is None: 1861 raise RuntimeError( 1862 "The run indices are not consistent across episodes!" 1863 ) 1864 else: 1865 multi_logs[index] += multi_logs["null"] 1866 multi_logs[index] += list( 1867 self._episode(run).get_metric_log(mode, metric) 1868 ) 1869 start = ( 1870 0 1871 if len(epoch_list[index]) == 0 1872 else epoch_list[index][-1] 1873 ) 1874 epoch_list[index] += [ 1875 x + start 1876 for x in self._episode(run).get_epoch_list(mode) 1877 ] 1878 multi_logs["null"] = None 1879 else: 1880 if len(multi_logs.keys()) > 1: 1881 raise RuntimeError( 1882 "Cannot plot a single-run episode after a multi-run episode!"
1883 ) 1884 multi_logs["null"] += list( 1885 self._episode(n).get_metric_log(mode, metric) 1886 ) 1887 start = ( 1888 0 1889 if len(epoch_list["null"]) == 0 1890 else epoch_list["null"][-1] 1891 ) 1892 epoch_list["null"] += [ 1893 x + start for x in self._episode(n).get_epoch_list(mode) 1894 ] 1895 if len(multi_logs.keys()) == 1: 1896 log = multi_logs["null"] 1897 epochs.append(epoch_list["null"]) 1898 else: 1899 log = tuple([v for k, v in multi_logs.items() if k != "null"]) 1900 epochs.append( 1901 tuple([v for k, v in epoch_list.items() if k != "null"]) 1902 ) 1903 else: 1904 runs = self._episodes().get_runs(name) 1905 if len(runs) > 1: 1906 log = [] 1907 for run in runs: 1908 tracked_metrics = self._episode(run).get_metrics() 1909 if metric in tracked_metrics: 1910 log.append( 1911 list( 1912 self._episode(run).get_metric_log(mode, metric) 1913 ) 1914 ) 1915 else: 1916 relevant = [] 1917 for m in tracked_metrics: 1918 m_split = m.split("_") 1919 if ( 1920 "_".join(m_split[:-1]) == metric 1921 and m_split[-1].isnumeric() 1922 ): 1923 relevant.append(m) 1924 if len(relevant) == 0: 1925 raise ValueError( 1926 f"The {metric} metric was not tracked at {run}" 1927 ) 1928 arr = 0 1929 for m in relevant: 1930 arr += self._episode(run).get_metric_log(mode, m) 1931 arr /= len(relevant) 1932 log.append(list(arr)) 1933 log = tuple(log) 1934 epochs.append( 1935 tuple( 1936 [ 1937 self._episode(run).get_epoch_list(mode) 1938 for run in runs 1939 ] 1940 ) 1941 ) 1942 else: 1943 tracked_metrics = self._episode(name).get_metrics() 1944 if metric in tracked_metrics: 1945 log = list(self._episode(name).get_metric_log(mode, metric)) 1946 else: 1947 relevant = [] 1948 for m in tracked_metrics: 1949 m_split = m.split("_") 1950 if ( 1951 "_".join(m_split[:-1]) == metric 1952 and m_split[-1].isnumeric() 1953 ): 1954 relevant.append(m) 1955 if len(relevant) == 0: 1956 raise ValueError( 1957 f"The {metric} metric was not tracked at {name}" 1958 ) 1959 arr = 0 1960 for m in relevant: 1961 arr += self._episode(name).get_metric_log(mode, m) 1962 arr /= len(relevant) 1963 log = list(arr) 1964 epochs.append(self._episode(name).get_epoch_list(mode)) 1965 logs.append(log) 1966 # if episode_labels is not None: 1967 # print(f'{len(episode_labels)=}, {len(logs)=}') 1968 # if len(episode_labels) != len(logs): 1969 1970 # raise ValueError( 1971 # f"The length of episode_labels ({len(episode_labels)}) has to be equal to the length of " 1972 # f"curves ({len(logs)})!" 1973 # ) 1974 # else: 1975 # labels = episode_labels 1976 if colors is None: 1977 colors = cm.rainbow(np.linspace(0, 1, len(logs))) 1978 if len(colors) != len(logs): 1979 raise ValueError( 1980 "The length of colors has to be equal to the length of curves (metrics * modes * episode_names)!" 
1981 ) 1982 plt.figure() 1983 length = 0 1984 for log, label, color, epoch_list in zip(logs, labels, colors, epochs): 1985 if type(log) is list: 1986 if len(log) > length: 1987 length = len(log) 1988 plt.plot( 1989 epoch_list, 1990 log, 1991 label=label, 1992 color=color, 1993 ) 1994 if add_highpoint_hlines: 1995 plt.axhline(np.max(log), linestyle="dashed", color=color) 1996 else: 1997 for l, xx in zip(log, epoch_list): 1998 if len(l) > length: 1999 length = len(l) 2000 plt.plot( 2001 xx, 2002 l, 2003 color=color, 2004 alpha=0.2, 2005 ) 2006 if not all([len(x) == len(log[0]) for x in log]): 2007 warnings.warn( 2008 f"Got logs with unequal lengths in parallel runs for {label}" 2009 ) 2010 log = list(log) 2011 epoch_list = list(epoch_list) 2012 for i, x in enumerate(epoch_list): 2013 to_remove = [] 2014 for j, y in enumerate(x[1:]): 2015 if y <= x[j]: # y is x[j + 1]; a non-increasing step means the epoch counter was restarted 2016 y_ind = x.index(y) 2017 to_remove += list(range(y_ind, j)) 2018 epoch_list[i] = [ 2019 y for j, y in enumerate(x) if j not in to_remove 2020 ] 2021 log[i] = [y for j, y in enumerate(log[i]) if j not in to_remove] 2022 length = min([len(x) for x in log]) 2023 for i in range(len(log)): 2024 log[i] = log[i][:length] 2025 epoch_list[i] = epoch_list[i][:length] 2026 if not all([x == epoch_list[0] for x in epoch_list]): 2027 raise RuntimeError( 2028 f"Got different epoch indices in parallel runs for {label}" 2029 ) 2030 mean = np.array(log).mean(0) 2031 plt.plot( 2032 epoch_list[0], 2033 mean, 2034 label=label, 2035 color=color, 2036 ) 2037 if add_highpoint_hlines: 2038 plt.axhline(np.max(mean), linestyle="dashed", color=color) 2039 for x in add_hlines: 2040 label = None 2041 if isinstance(x, Iterable): 2042 x, label = x 2043 plt.axhline(x, label=label) 2044 plt.xlim((0, length)) 2045 2046 plt.legend() 2047 plt.xlabel("epochs") 2048 if len(metrics) == 1: 2049 plt.ylabel(metrics[0]) 2050 else: 2051 plt.ylabel("value") 2052 if title is None: 2053 if len(episode_names) == 1: 2054 title = episode_names[0] 2055 elif len(metrics) == 1: 2056 title = metrics[0] 2057 if epoch_limits is not None: 2058 plt.xlim(epoch_limits) 2059 if title is not None: 2060 plt.title(title) 2061 if save_path is not None: 2062 plt.savefig(save_path) # save before plt.show() so the figure is not cleared first 2063 plt.show() 2064 2065 def update_parameters( 2066 self, 2067 parameters_update: Dict = None, 2068 load_search: str = None, 2069 load_parameters: List = None, 2070 round_to_binary: List = None, 2071 ) -> None: 2072 """ 2073 Update the parameters in the project config files 2074 2075 Parameters 2076 ---------- 2077 parameters_update : dict, optional 2078 a dictionary of parameter updates 2079 load_search : str, optional 2080 the name of hyperparameter search results to load to config 2081 load_parameters : list, optional 2082 a list of string names of the parameters to load from the search 2083 round_to_binary : list, optional 2084 a list of string names of the loaded parameters that should be rounded to the nearest power of two 2085 """ 2086 2087 keys = [ 2088 "general", 2089 "losses", 2090 "metrics", 2091 "ssl", 2092 "training", 2093 "data", 2094 ] 2095 parameters = self._read_parameters(catch_blanks=False) 2096 if parameters_update is not None: 2097 if "model" in parameters_update: 2098 model_params = parameters_update.pop("model") 2099 else: 2100 model_params = None 2101 if "features" in parameters_update: 2102 feat_params = parameters_update.pop("features") 2103 else: 2104 feat_params = None 2105 if "augmentations" in parameters_update: 2106 aug_params = parameters_update.pop("augmentations") 2107 else: 2108
aug_params = None 2109 parameters = self._update(parameters, parameters_update) 2110 model_name = parameters["general"]["model_name"] 2111 parameters["model"] = self._open_yaml( 2112 os.path.join(self.project_path, "config", "model", f"{model_name}.yaml") 2113 ) 2114 if model_params is not None: 2115 parameters["model"] = self._update(parameters["model"], model_params) 2116 feat_name = parameters["general"]["feature_extraction"] 2117 parameters["features"] = self._open_yaml( 2118 os.path.join( 2119 self.project_path, "config", "features", f"{feat_name}.yaml" 2120 ) 2121 ) 2122 if feat_params is not None: 2123 parameters["features"] = self._update( 2124 parameters["features"], feat_params 2125 ) 2126 aug_name = options.extractor_to_transformer[ 2127 parameters["general"]["feature_extraction"] 2128 ] 2129 parameters["augmentations"] = self._open_yaml( 2130 os.path.join( 2131 self.project_path, "config", "augmentations", f"{aug_name}.yaml" 2132 ) 2133 ) 2134 if aug_params is not None: 2135 parameters["augmentations"] = self._update( 2136 parameters["augmentations"], aug_params 2137 ) 2138 if load_search is not None: 2139 parameters_update, model_name = self._searches().get_best_params( 2140 load_search, load_parameters, round_to_binary 2141 ) 2142 parameters["general"]["model_name"] = model_name 2143 parameters["model"] = self._open_yaml( 2144 os.path.join(self.project_path, "config", "model", f"{model_name}.yaml") 2145 ) 2146 parameters = self._update(parameters, parameters_update) 2147 for key in keys: 2148 with open( 2149 os.path.join(self.project_path, "config", f"{key}.yaml"), "w", encoding="utf-8" 2150 ) as f: 2151 YAML().dump(parameters[key], f) 2152 model_name = parameters["general"]["model_name"] 2153 model_path = os.path.join( 2154 self.project_path, "config", "model", f"{model_name}.yaml" 2155 ) 2156 with open(model_path, "w", encoding="utf-8") as f: 2157 YAML().dump(parameters["model"], f) 2158 features_name = parameters["general"]["feature_extraction"] 2159 features_path = os.path.join( 2160 self.project_path, "config", "features", f"{features_name}.yaml" 2161 ) 2162 with open(features_path, "w", encoding="utf-8") as f: 2163 YAML().dump(parameters["features"], f) 2164 aug_name = options.extractor_to_transformer[features_name] 2165 aug_path = os.path.join( 2166 self.project_path, "config", "augmentations", f"{aug_name}.yaml" 2167 ) 2168 with open(aug_path, "w", encoding="utf-8") as f: 2169 YAML().dump(parameters["augmentations"], f) 2170 2171 def get_summary( 2172 self, 2173 episode_names: list, 2174 method: str = "last", 2175 average: int = 1, 2176 metrics: List = None, 2177 ) -> Dict: 2178 """ 2179 Get a summary of episode statistics 2180 2181 If the episode has multiple runs, the statistics will be aggregated over all of them. 
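
        For example, to average the three best validation epochs over all runs of two (hypothetical) episodes:

        >>> stats = project.get_summary(["episode_1", "episode_2"], method="best", average=3)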

        Parameters
        ----------
        episode_names : list
            a list of names of the episodes to summarize
        method : ["best", "last", "epoch{n}"]
            the method for choosing the epochs ("epoch{n}" picks the value at epoch n, e.g. "epoch10")
        average : int, default 1
            the number of epochs to average over (for each run)
        metrics : list, optional
            a list of metrics

        Returns
        -------
        statistics : dict
            a nested dictionary where first-level keys are metric names and second-level keys are 'mean' for the mean
            and 'std' for the standard deviation
        """

        runs = []
        for episode_name in episode_names:
            runs_ep = self._episodes().get_runs(episode_name)
            if len(runs_ep) == 0:
                raise RuntimeError(
                    f"There is no {episode_name} episode in the project memory"
                )
            runs += runs_ep
        if metrics is None:
            metrics = self._episode(runs[0]).get_metrics()

        values = {m: [] for m in metrics}
        for run in runs:
            for m in metrics:
                log = self._episode(run).get_metric_log(mode="val", metric_name=m)
                if method == "best":
                    log = sorted(log)
                    values[m] += list(log[-average:])
                elif method == "last":
                    if len(log) == 0:
                        episodes = self._episodes().data
                        if average == 1 and ("results", m) in episodes.columns:
                            values[m] += [episodes.loc[run, ("results", m)]]
                        else:
                            raise RuntimeError(f"Did not find {m} metric for {run} run")
                    else:
                        values[m] += list(log[-average:])
                elif method.startswith("epoch"):
                    epoch = int(method[5:]) - 1
                    pars = self._episodes().load_parameters(run)
                    step = int(pars["training"]["validation_interval"])
                    values[m] += [log[epoch // step]]
                else:
                    raise ValueError(
                        f"The {method} method is not recognized! Please choose from ['last', 'best', 'epoch...']"
                    )
        statistics = defaultdict(lambda: {})
        for m, v in values.items():
            statistics[m]["mean"] = np.mean(v)
            statistics[m]["std"] = np.std(v)
        print(f"SUMMARY {episode_names}")
        for m, v in statistics.items():
            print(f'{m}: mean {v["mean"]:.3f}, std {v["std"]:.3f}')
        print("\n")
        return dict(statistics)

    @staticmethod
    def remove_project(name: str, projects_path: str = None) -> None:
        """
        Remove all project files and experiment records and results
        """

        if projects_path is None:
            projects_path = os.path.join(str(Path.home()), "DLC2Action")
        project_path = os.path.join(projects_path, name)
        if os.path.exists(project_path):
            shutil.rmtree(project_path)

    def remove_saved_features(
        self,
        dataset_names: List = None,
        exceptions: List = None,
        remove_active: bool = False,
    ) -> None:
        """
        Remove saved pre-computed dataset files

        By default, all pre-computed features will be deleted.
        No essential information is lost: these files only save computation time. However, be careful not to
        delete datasets while training or inference is running.
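
        For example, to delete all pre-computed features except those of a (hypothetical) dataset that is still needed:

        >>> project.remove_saved_features(exceptions=["my_dataset"])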

        Parameters
        ----------
        dataset_names : list, optional
            a list of dataset names to delete (by default all names are added)
        exceptions : list, optional
            a list of dataset names that should not be deleted
        remove_active : bool, default False
            if `False`, datasets used by unfinished episodes will not be deleted
        """

        print("Removing datasets...")
        if dataset_names is None:
            dataset_names = []
        if exceptions is None:
            exceptions = []
        if not remove_active:
            exceptions += self._episodes().get_active_datasets()
        dataset_path = os.path.join(self.project_path, "saved_datasets")
        if os.path.exists(dataset_path):
            if dataset_names == []:
                dataset_names = set([f.split(".")[0] for f in os.listdir(dataset_path)])

            to_remove = [
                x
                for x in dataset_names
                if os.path.exists(os.path.join(dataset_path, x)) and x not in exceptions
            ]
            if len(to_remove) > 2:
                to_remove = tqdm(to_remove)
            for dataset in to_remove:
                shutil.rmtree(os.path.join(dataset_path, dataset))
            to_remove = [
                f"{x}.pickle"
                for x in dataset_names
                if os.path.exists(os.path.join(dataset_path, f"{x}.pickle"))
                and x not in exceptions
            ]
            for dataset in to_remove:
                os.remove(os.path.join(dataset_path, dataset))
            names = self._saved_datasets().dataset_names()
            self._saved_datasets().remove(names)
        print("\n")

    def remove_extra_checkpoints(
        self, episode_names: List = None, exceptions: List = None
    ) -> None:
        """
        Remove intermediate model checkpoint files (keeping only the last epoch)

        By default, all intermediate checkpoints will be deleted.
        Files in the model folder that are not associated with any record in the meta files are also deleted.
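
        For example, to clean up every episode except a (hypothetical) one whose intermediate checkpoints are still needed:

        >>> project.remove_extra_checkpoints(exceptions=["episode_1"])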

        Parameters
        ----------
        episode_names : list, optional
            a list of episode names to clean (by default all names are added)
        exceptions : list, optional
            a list of episode names that should not be cleaned
        """

        model_path = os.path.join(self.project_path, "results", "model")
        try:
            all_names = self._episodes().data.index
        except:
            all_names = os.listdir(model_path)
        if episode_names is None:
            episode_names = all_names
        if exceptions is None:
            exceptions = []
        to_remove = [x for x in episode_names if x not in exceptions]
        folders = os.listdir(model_path)
        for folder in folders:
            if folder not in all_names:
                shutil.rmtree(os.path.join(model_path, folder))
            elif folder in to_remove:
                files = os.listdir(os.path.join(model_path, folder))
                for file in sorted(files)[:-1]:
                    os.remove(os.path.join(model_path, folder, file))

    def remove_search(self, search_name: str) -> None:
        """
        Remove a hyperparameter search record

        Parameters
        ----------
        search_name : str
            the name of the search to remove
        """

        self._searches().remove_episode(search_name)
        graph_path = os.path.join(self.project_path, "results", "searches", search_name)
        if os.path.exists(graph_path):
            shutil.rmtree(graph_path)

    def remove_prediction(self, prediction_name: str) -> None:
        """
        Remove a prediction record

        Parameters
        ----------
        prediction_name : str
            the name of the prediction to remove
        """

        self._predictions().remove_episode(prediction_name)
        prediction_path = os.path.join(
            self.project_path, "results", "predictions", prediction_name
        )
        if os.path.exists(prediction_path):
            shutil.rmtree(prediction_path)

    def remove_episode(self, episode_name: str) -> None:
        """
        Remove all model, logs and metafile records related to an episode

        Parameters
        ----------
        episode_name : str
            the name of the episode to remove
        """

        runs = self._episodes().get_runs(episode_name)
        runs.append(episode_name)
        for run in runs:
            self._episodes().remove_episode(run)
            model_path = os.path.join(self.project_path, "results", "model", run)
            if os.path.exists(model_path):
                shutil.rmtree(model_path)
            log_path = os.path.join(self.project_path, "results", "logs", f"{run}.txt")
            if os.path.exists(log_path):
                os.remove(log_path)

    def prune_unfinished(self, exceptions: List = None) -> List:
        """
        Remove all interrupted episodes

        Remove all episodes that either don't have a log file, have fewer epochs in the log file than in
        the training parameters, or have a model folder but no meta record. Note that this can remove episodes that are
        currently running!
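
        For example, to prune everything except a (hypothetical) episode that is still training:

        >>> pruned = project.prune_unfinished(exceptions=["episode_running"])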
2410 2411 Parameters 2412 ---------- 2413 exceptions : list 2414 the episodes to keep even if they are interrupted 2415 2416 Returns 2417 ------- 2418 pruned : list 2419 a list of the episode names that were pruned 2420 """ 2421 2422 if exceptions is None: 2423 exceptions = [] 2424 unfinished = self._episodes().unfinished_episodes() 2425 unfinished = [x for x in unfinished if x not in exceptions] 2426 model_folders = os.listdir(os.path.join(self.project_path, "results", "model")) 2427 unfinished += [ 2428 x for x in model_folders if x not in self._episodes().list_episodes().index 2429 ] 2430 print(f"PRUNING {unfinished}") 2431 for episode_name in unfinished: 2432 self.remove_episode(episode_name) 2433 print(f"\n") 2434 return unfinished 2435 2436 def prediction_path(self, prediction_name: str) -> str: 2437 """ 2438 Get the path where prediction files are saved 2439 2440 Parameters 2441 ---------- 2442 prediction_name : str 2443 name of the prediction 2444 2445 Returns 2446 ------- 2447 prediction_path : str 2448 the file path 2449 """ 2450 2451 return os.path.join( 2452 self.project_path, "results", "predictions", f"{prediction_name}" 2453 ) 2454 2455 @classmethod 2456 def print_data_types(cls): 2457 print("DATA TYPES:") 2458 for key, value in cls.data_types().items(): 2459 print(f"{key}:") 2460 print(value.__doc__) 2461 2462 @classmethod 2463 def print_annotation_types(cls): 2464 print("ANNOTATION TYPES:") 2465 for key, value in cls.annotation_types().items(): 2466 print(f"{key}:") 2467 print(value.__doc__) 2468 2469 @staticmethod 2470 def data_types() -> List: 2471 """ 2472 Get available data types 2473 2474 Returns 2475 ------- 2476 list 2477 available data types 2478 """ 2479 2480 return options.input_stores 2481 2482 @staticmethod 2483 def annotation_types() -> List: 2484 """ 2485 Get available annotation types 2486 2487 Returns 2488 ------- 2489 list 2490 available annotation types 2491 """ 2492 2493 return options.annotation_stores 2494 2495 def _save_mask(self, file: Dict, mask_name: str): 2496 """ 2497 Save a mask file 2498 """ 2499 2500 if not os.path.exists(self._mask_path()): 2501 os.mkdir(self._mask_path()) 2502 with open(os.path.join(self._mask_path(), mask_name + ".pickle"), "wb") as f: 2503 pickle.dump(file, f) 2504 2505 def _load_mask(self, mask_name: str) -> Dict: 2506 """ 2507 Load a mask file 2508 """ 2509 2510 with open(os.path.join(self._mask_path(), mask_name + ".pickle"), "rb") as f: 2511 data = pickle.load(f) 2512 return data 2513 2514 def _thresholds(self) -> DecisionThresholds: 2515 """ 2516 Get the decision thresholds meta object 2517 """ 2518 2519 return DecisionThresholds(self._thresholds_path()) 2520 2521 def _episodes(self) -> SavedRuns: 2522 """ 2523 Get the episodes meta object 2524 2525 Returns 2526 ------- 2527 episodes : SavedRuns 2528 the episodes meta object 2529 """ 2530 2531 try: 2532 return SavedRuns(self._episodes_path(), self.project_path) 2533 except: 2534 self.load_metadata_backup() 2535 return SavedRuns(self._episodes_path(), self.project_path) 2536 2537 def _predictions(self) -> SavedRuns: 2538 """ 2539 Get the predictions meta object 2540 2541 Returns 2542 ------- 2543 predictions : SavedRuns 2544 the predictions meta object 2545 """ 2546 2547 try: 2548 return SavedRuns(self._predictions_path(), self.project_path) 2549 except: 2550 self.load_metadata_backup() 2551 return SavedRuns(self._predictions_path(), self.project_path) 2552 2553 def _saved_datasets(self) -> SavedStores: 2554 """ 2555 Get the datasets meta object 2556 2557 Returns 2558 
------- 2559 datasets : SavedStores 2560 the datasets meta object 2561 """ 2562 2563 try: 2564 return SavedStores(self._saved_datasets_path()) 2565 except: 2566 self.load_metadata_backup() 2567 return SavedStores(self._saved_datasets_path()) 2568 2569 def _prediction(self, name: str) -> Run: 2570 """ 2571 Get a prediction meta object 2572 2573 Parameters 2574 ---------- 2575 name : str 2576 episode name 2577 2578 Returns 2579 ------- 2580 prediction : Run 2581 the prediction meta object 2582 """ 2583 2584 try: 2585 return Run(name, self.project_path, meta_path=self._predictions_path()) 2586 except: 2587 self.load_metadata_backup() 2588 return Run(name, self.project_path, meta_path=self._predictions_path()) 2589 2590 def _episode(self, name: str) -> Run: 2591 """ 2592 Get an episode meta object 2593 2594 Parameters 2595 ---------- 2596 name : str 2597 episode name 2598 2599 Returns 2600 ------- 2601 episode : Run 2602 the episode meta object 2603 """ 2604 2605 try: 2606 return Run(name, self.project_path, meta_path=self._episodes_path()) 2607 except: 2608 self.load_metadata_backup() 2609 return Run(name, self.project_path, meta_path=self._episodes_path()) 2610 2611 def _searches(self) -> Searches: 2612 """ 2613 Get the hyperparameter search meta object 2614 2615 Returns 2616 ------- 2617 searches : Searches 2618 the searches meta object 2619 """ 2620 2621 try: 2622 return Searches(self._searches_path(), self.project_path) 2623 except: 2624 self.load_metadata_backup() 2625 return Searches(self._searches_path(), self.project_path) 2626 2627 def _update_configs(self) -> None: 2628 """ 2629 Update the project config files with newly added files and parameters 2630 """ 2631 2632 self.update_parameters({"data": {"data_path": self.data_path}}) 2633 folders = ["augmentations", "features", "model"] 2634 original_path = os.path.join( 2635 os.path.dirname(os.path.dirname(__file__)), "config" 2636 ) 2637 project_path = os.path.join(self.project_path, "config") 2638 filenames = [x for x in os.listdir(original_path) if x.endswith("yaml")] 2639 for folder in folders: 2640 filenames += [ 2641 os.path.join(folder, x) 2642 for x in os.listdir(os.path.join(original_path, folder)) 2643 ] 2644 filenames.append(os.path.join("data", f"{self.data_type}.yaml")) 2645 if self.annotation_type != "none": 2646 filenames.append(os.path.join("annotation", f"{self.annotation_type}.yaml")) 2647 for file in filenames: 2648 filepath_original = os.path.join(original_path, file) 2649 if file.startswith("data") or file.startswith("annotation"): 2650 file = os.path.basename(file) 2651 filepath_project = os.path.join(project_path, file) 2652 if not os.path.exists(filepath_project): 2653 shutil.copy(filepath_original, filepath_project) 2654 else: 2655 original_pars = self._open_yaml(filepath_original) 2656 project_pars = self._open_yaml(filepath_project) 2657 to_remove = [] 2658 for key, value in project_pars.items(): 2659 if key not in original_pars: 2660 if key not in ["data_type", "annotation_type"]: 2661 to_remove.append(key) 2662 for key in to_remove: 2663 project_pars.pop(key) 2664 to_remove = [] 2665 for key, value in original_pars.items(): 2666 if key in project_pars: 2667 to_remove.append(key) 2668 for key in to_remove: 2669 original_pars.pop(key) 2670 project_pars = self._update(project_pars, original_pars) 2671 with open(filepath_project, "w", encoding="utf-8") as f: 2672 YAML().dump(project_pars, f) 2673 2674 def _update_project(self) -> None: 2675 """ 2676 Update project files with the current version 2677 """ 2678 
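        # NB: `project_version < __version__` below is a lexicographic string
        # comparison, so e.g. "0.9" > "0.10"; it only orders versions correctly
        # while every component of the version string stays single-digit.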
        version_file = self._version_path()
        ok = True
        if not os.path.exists(version_file):
            ok = False
        else:
            with open(version_file) as f:
                project_version = f.read()
            if project_version < __version__:
                ok = False
            elif project_version > __version__:
                warnings.warn(
                    f"The project expects a higher dlc2action version ({project_version}), please update!"
                )
        if not ok:
            project_config_path = os.path.join(self.project_path, "config")
            config_path = os.path.join(
                os.path.dirname(os.path.dirname(__file__)), "config"
            )
            episodes = self._episodes()
            folders = ["annotation", "augmentations", "data", "features", "model"]

            project_annotation_configs = os.listdir(
                os.path.join(project_config_path, "annotation")
            )
            annotation_configs = os.listdir(os.path.join(config_path, "annotation"))
            for ann_config in annotation_configs:
                if ann_config not in project_annotation_configs:
                    shutil.copytree(
                        os.path.join(config_path, "annotation", ann_config),
                        os.path.join(project_config_path, "annotation", ann_config),
                        dirs_exist_ok=True,
                    )
                else:
                    project_pars = self._open_yaml(
                        os.path.join(project_config_path, "annotation", ann_config)
                    )
                    pars = self._open_yaml(
                        os.path.join(config_path, "annotation", ann_config)
                    )
                    new_keys = set(pars.keys()) - set(project_pars.keys())
                    for key in new_keys:
                        project_pars[key] = pars[key]
                        c = self._get_comment(pars.ca.items.get(key))
                        project_pars.yaml_add_eol_comment(c, key=key)
                        episodes.update(
                            condition=f"general/annotation_type::={ann_config}",
                            update={f"data/{key}": pars[key]},
                        )

    def _initialize_project(
        self,
        data_type: str,
        annotation_type: str = None,
        data_path: str = None,
        annotation_path: str = None,
        copy: bool = True,
    ) -> None:
        """
        Initialize a new project
        """

        if data_type not in self.data_types():
            raise ValueError(
                f"The {data_type} data type is not available. "
                f"Please choose from {self.data_types()}"
            )
        if annotation_type not in self.annotation_types():
            raise ValueError(
                f"The {annotation_type} annotation type is not available. "
                f"Please choose from {self.annotation_types()}"
            )
        os.mkdir(self.project_path)
        folders = ["results", "saved_datasets", "meta", "config"]
        for f in folders:
            os.mkdir(os.path.join(self.project_path, f))
        results_subfolders = [
            "model",
            "logs",
            "predictions",
            "splits",
            "searches",
        ]
        for sf in results_subfolders:
            os.mkdir(os.path.join(self.project_path, "results", sf))
        if data_path is not None:
            if copy:
                os.mkdir(os.path.join(self.project_path, "data"))
                shutil.copytree(
                    data_path,
                    os.path.join(self.project_path, "data"),
                    dirs_exist_ok=True,
                )
                data_path = os.path.join(self.project_path, "data")
        if annotation_path is not None:
            if copy:
                os.mkdir(os.path.join(self.project_path, "annotation"))
                shutil.copytree(
                    annotation_path,
                    os.path.join(self.project_path, "annotation"),
                    dirs_exist_ok=True,
                )
                annotation_path = os.path.join(self.project_path, "annotation")
        self._generate_config(
            data_type,
            annotation_type,
            data_path=data_path,
            annotation_path=annotation_path,
        )
        self._generate_meta()

    def _read_types(self) -> Tuple[str, str]:
        """
        Get the data type and annotation type from existing project files
        """

        config_path = os.path.join(self.project_path, "config", "general.yaml")
        with open(config_path) as f:
            pars = YAML().load(f)
        data_type = pars["data_type"]
        annotation_type = pars["annotation_type"]
        return annotation_type, data_type

    def _read_paths(self) -> Tuple[str, str]:
        """
        Get the data path and annotation path from existing project files
        """

        config_path = os.path.join(self.project_path, "config", "data.yaml")
        with open(config_path) as f:
            pars = YAML().load(f)
        data_path = pars["data_path"]
        annotation_path = pars["annotation_path"]
        return annotation_path, data_path

    def _generate_config(
        self, data_type: str, annotation_type: str, data_path: str, annotation_path: str
    ) -> None:
        """
        Initialize the config files
        """

        default_path = os.path.join(
            os.path.dirname(os.path.dirname(__file__)), "config"
        )
        config_path = os.path.join(self.project_path, "config")
        files = ["losses", "metrics", "ssl", "training"]
        for f in files:
            shutil.copy(os.path.join(default_path, f"{f}.yaml"), config_path)
        shutil.copytree(
            os.path.join(default_path, "model"), os.path.join(config_path, "model")
        )
        shutil.copytree(
            os.path.join(default_path, "features"),
            os.path.join(config_path, "features"),
        )
        shutil.copytree(
            os.path.join(default_path, "augmentations"),
            os.path.join(config_path, "augmentations"),
        )
        yaml = YAML()
        data_params = None  # stays None if there is no config file for this data type
        data_param_path = os.path.join(default_path, "data", f"{data_type}.yaml")
        if os.path.exists(data_param_path):
            with open(data_param_path, encoding="utf-8") as f:
                data_params = yaml.load(f)
        if data_params is None:
            data_params = {}
        if annotation_type is None:
            ann_params = {}
        else:
            ann_param_path = os.path.join(
                default_path, "annotation", f"{annotation_type}.yaml"
            )
            if os.path.exists(ann_param_path):
                ann_params = self._open_yaml(ann_param_path)
            elif annotation_type == "none":
                ann_params = {}
            else:
                raise ValueError(
                    f"The {annotation_type} annotation type is not available. 
" 2858 f"Please choose from {BehaviorDataset.annotation_types()}" 2859 ) 2860 if ann_params is None: 2861 ann_params = {} 2862 data_params = self._update(data_params, ann_params) 2863 data_params["data_path"] = data_path 2864 data_params["annotation_path"] = annotation_path 2865 with open(os.path.join(config_path, "data.yaml"), "w", encoding="utf-8") as f: 2866 yaml.dump(data_params, f) 2867 with open(os.path.join(default_path, "general.yaml"), encoding="utf-8") as f: 2868 general_params = yaml.load(f) 2869 general_params["data_type"] = data_type 2870 general_params["annotation_type"] = annotation_type 2871 with open(os.path.join(config_path, "general.yaml"), "w", encoding="utf-8") as f: 2872 yaml.dump(general_params, f) 2873 2874 def _generate_meta(self) -> None: 2875 """ 2876 Initialize the meta files 2877 """ 2878 2879 config_file = os.path.join(self.project_path, "config") 2880 meta_fields = ["time"] 2881 columns = [("meta", field) for field in meta_fields] 2882 episodes = pd.DataFrame(columns=pd.MultiIndex.from_tuples(columns)) 2883 episodes.to_pickle(self._episodes_path()) 2884 meta_fields = ["time", "objective"] 2885 result_fields = ["best_params", "best_value"] 2886 columns = [("meta", field) for field in meta_fields] + [ 2887 ("results", field) for field in result_fields 2888 ] 2889 searches = pd.DataFrame(columns=pd.MultiIndex.from_tuples(columns)) 2890 searches.to_pickle(self._searches_path()) 2891 meta_fields = ["time"] 2892 columns = [("meta", field) for field in meta_fields] 2893 predictions = pd.DataFrame(columns=pd.MultiIndex.from_tuples(columns)) 2894 predictions.to_pickle(self._predictions_path()) 2895 with open(os.path.join(config_file, "data.yaml")) as f: 2896 data_keys = list(YAML().load(f).keys()) 2897 saved_data = pd.DataFrame(columns=data_keys) 2898 saved_data.to_pickle(self._saved_datasets_path()) 2899 pd.DataFrame().to_pickle(self._thresholds_path()) 2900 # with open(self._version_path()) as f: 2901 # f.write(__version__) 2902 2903 def _open_yaml(self, path: str) -> CommentedMap: 2904 """ 2905 Load a parameter dictionary from a .yaml file 2906 """ 2907 2908 with open(path, encoding="utf-8") as f: 2909 data = YAML().load(f) 2910 if data is None: 2911 data = {} 2912 return data 2913 2914 def _compare(self, d: Dict, u: Dict, allow_diff: float = 1e-7): 2915 """ 2916 Compare nested dictionaries with 'almost equal' condition 2917 """ 2918 2919 ok = True 2920 if u.keys() != d.keys(): 2921 ok = False 2922 else: 2923 for k, v in u.items(): 2924 if isinstance(v, Mapping): 2925 ok = self._compare(d[k], v, allow_diff=allow_diff) 2926 else: 2927 if isinstance(v, float) or isinstance(d[k], float): 2928 if not isinstance(d[k], float) and not isinstance(d[k], int): 2929 ok = False 2930 elif not isinstance(v, float) and not isinstance(v, int): 2931 ok = False 2932 elif np.abs(v - d[k]) > allow_diff: 2933 ok = False 2934 elif v != d[k]: 2935 ok = False 2936 return ok 2937 2938 def _check_comment(self, comment_sequence: List) -> bool: 2939 """ 2940 Check if a comment already exists in a ruamel.yaml comment sequence 2941 """ 2942 2943 if comment_sequence is None: 2944 return False 2945 c = self._get_comment(comment_sequence) 2946 if c != "": 2947 return True 2948 else: 2949 return False 2950 2951 def _get_comment(self, comment_sequence: List, strip=True) -> str: 2952 """ 2953 Get the comment string from a ruamel.yaml comment sequence 2954 """ 2955 2956 if comment_sequence is None: 2957 return "" 2958 c = "" 2959 for cm in comment_sequence: 2960 if cm is not None: 2961 if 
isinstance(cm, Iterable): 2962 for c in cm: 2963 if c is not None: 2964 c = c.value 2965 break 2966 break 2967 else: 2968 c = cm.value 2969 break 2970 if strip: 2971 c = c.strip() 2972 return c 2973 2974 def _update(self, d: Union[CommentedMap, Dict], u: Union[CommentedMap, Dict]): 2975 """ 2976 Update a nested dictionary 2977 """ 2978 2979 if "general" in u and "model_name" in u["general"] and "model" in d: 2980 model_name = u["general"]["model_name"] 2981 if d["general"]["model_name"] != model_name: 2982 d["model"] = self._open_yaml( 2983 os.path.join( 2984 self.project_path, "config", "model", f"{model_name}.yaml" 2985 ) 2986 ) 2987 d_copied = deepcopy(d) 2988 for k, v in u.items(): 2989 if ( 2990 k in d_copied 2991 and isinstance(d_copied[k], list) 2992 and isinstance(v, Mapping) 2993 and all([isinstance(x, int) for x in v.keys()]) 2994 ): 2995 for kk, vv in v.items(): 2996 d_copied[k][kk] = vv 2997 elif ( 2998 isinstance(v, Mapping) 2999 and k in d_copied 3000 and isinstance(d_copied[k], Mapping) 3001 ): 3002 if d_copied[k] is None: 3003 d_k = CommentedMap() 3004 else: 3005 d_k = d_copied[k] 3006 d_copied[k] = self._update(d_k, v) 3007 else: 3008 d_copied[k] = v 3009 if isinstance(u, CommentedMap) and u.ca.items.get(k) is not None: 3010 c = self._get_comment(u.ca.items.get(k), strip=False) 3011 if isinstance(d_copied, CommentedMap) and not self._check_comment( 3012 d_copied.ca.items.get(k) 3013 ): 3014 d_copied.yaml_add_eol_comment(c, key=k) 3015 return d_copied 3016 3017 def _update_with_search( 3018 self, 3019 d: Dict, 3020 search_name: str, 3021 load_parameters: list = None, 3022 round_to_binary: list = None, 3023 ): 3024 """ 3025 Update a dictionary with best parameters from a hyperparameter search 3026 """ 3027 3028 u, _ = self._searches().get_best_params( 3029 search_name, load_parameters, round_to_binary 3030 ) 3031 return self._update(d, u) 3032 3033 def _read_parameters(self, catch_blanks=True) -> Dict: 3034 """ 3035 Compose a parameter dictionary to create a task from the config files 3036 """ 3037 3038 config_path = os.path.join(self.project_path, "config") 3039 keys = [ 3040 "data", 3041 "general", 3042 "losses", 3043 "metrics", 3044 "ssl", 3045 "training", 3046 ] 3047 parameters = {} 3048 for key in keys: 3049 parameters[key] = self._open_yaml(os.path.join(config_path, f"{key}.yaml")) 3050 features = parameters["general"]["feature_extraction"] 3051 parameters["features"] = self._open_yaml( 3052 os.path.join(config_path, "features", f"{features}.yaml") 3053 ) 3054 transformer = options.extractor_to_transformer[features] 3055 parameters["augmentations"] = self._open_yaml( 3056 os.path.join(config_path, "augmentations", f"{transformer}.yaml") 3057 ) 3058 model = parameters["general"]["model_name"] 3059 parameters["model"] = self._open_yaml( 3060 os.path.join(config_path, "model", f"{model}.yaml") 3061 ) 3062 # input = parameters["general"]["input"] 3063 # parameters["model"] = self._open_yaml( 3064 # os.path.join(config_path, "model", f"{model}.yaml") 3065 # ) 3066 if catch_blanks: 3067 blanks = self._get_blanks() 3068 if len(blanks) > 0: 3069 self.list_blanks() 3070 raise ValueError( 3071 f"Please fill in all the blanks before running experiments" 3072 ) 3073 return parameters 3074 3075 def set_main_parameters(self, model_name: str = None, metric_names: List = None): 3076 """ 3077 Select the model and the metrics 3078 3079 Parameters 3080 ---------- 3081 model_name : str, optional 3082 model name; run `project.help("model") to find out more 3083 metric_names : list, 
optional
            a list of metric function names; run `project.help("metrics")` to find out more
        """

        pars = {"general": {}}
        if model_name is not None:
            assert model_name in options.models
            pars["general"]["model_name"] = model_name
        if metric_names is not None:
            for metric in metric_names:
                assert metric in options.metrics
            pars["general"]["metric_functions"] = metric_names
        self.update_parameters(pars)

    def help(self, keyword: str = None):
        """
        Get information on available options

        Parameters
        ----------
        keyword : str, optional
            the keyword for options (run without arguments to see which keywords are available)
        """

        if keyword is None:
            print("AVAILABLE HELP FUNCTIONS:")
            print("- Try running `project.help(keyword)` with the following keywords:")
            print("    - model: to get more information on available models,")
            print(
                "    - features: to get more information on available feature extraction modes,"
            )
            print(
                "    - partition_method: to get more information on available train/test/val partitioning methods,"
            )
            print("    - metrics: to see a list of available metric functions,")
            print("    - data: to see help for the expected data structure.")
            print(
                "- To start working with this project, first run `project.list_blanks()` to check which parameters need to be filled in."
            )
            print(
                "- After a model and metrics are set, run `project.list_basic_parameters()` to see a list of the most important parameters that you might want to modify."
            )
            print(
                "- If you want to dig deeper, get the full dictionary with `project._read_parameters()` (it is a `ruamel.yaml.comments.CommentedMap` instance)."
            )
        elif keyword == "model":
            print("MODELS:")
            for key, model in options.models.items():
                print(f"{key}:")
                print(model.__doc__)
        elif keyword == "features":
            print("FEATURE EXTRACTORS:")
            for key, extractor in options.feature_extractors.items():
                print(f"{key}:")
                print(extractor.__doc__)
        elif keyword == "partition_method":
            print("PARTITION METHODS:")
            print(
                BehaviorDataset.partition_train_test_val.__doc__.split(
                    "The partitioning method:"
                )[1].split("val_frac :")[0]
            )
        elif keyword == "metrics":
            print("METRICS:")
            for key, metric in options.metrics.items():
                print(f"{key}:")
                print(metric.__doc__)
        elif keyword == "data":
            print("DATA:")
            print(f"Video data: {self.data_type}")
            print(options.input_stores[self.data_type].__doc__)
            print(f"Annotation data: {self.annotation_type}")
            print(options.annotation_stores[self.annotation_type].__doc__)
            print(
                "Annotation path and data path don't have to be separate; you can keep everything in one folder."
            )
        else:
            raise ValueError(f"The {keyword} keyword is not recognized")
        print("\n")

    def _process_value(self, value):
        if isinstance(value, str):
            value = f'"{value}"'
        elif isinstance(value, CommentedSet):
            value = {x for x in value}
        return value

    def _get_blanks(self):
        caught = []
        parameters = self._read_parameters(catch_blanks=False)
        for big_key, big_value in parameters.items():
            for key, value in big_value.items():
                if value == "???":
                    caught.append(
                        (big_key, key, self._get_comment(big_value.ca.items.get(key)))
                    )
        return caught

    def list_blanks(self, blanks=None):
        """
        List parameters that need to be filled in

        Parameters
        ----------
        blanks : list, optional
            a list of the parameters to list, if already known
        """

        if blanks is None:
            blanks = self._get_blanks()
        if len(blanks) > 0:
            to_update = defaultdict(lambda: [])
            for b, k, c in blanks:
                to_update[b].append((k, c))
            print("Before running experiments, please update all the blanks.")
            print("To do that, you can run the code below.")
            print("--------------------------------------------------------")
            print(f"project.update_parameters(")
            print(f"    {{")
            for big_key, keys in to_update.items():
                print(f'        "{big_key}": {{')
                for key, comment in keys:
                    print(f'            "{key}": ..., {comment}')
                print(f"        }},")
            print(f"    }}")
            print(")")
            print("--------------------------------------------------------")
            print("Replace ... with relevant values.")
        else:
            print("There are no blanks left!")

    def list_basic_parameters(
        self,
    ):
        """
        Get a list of the most relevant parameters and the code to modify them
        """

        parameters = self._read_parameters()
        print("BASIC PARAMETERS:")
        model_name = parameters["general"]["model_name"]
        metric_names = parameters["general"]["metric_functions"]
        loss_name = parameters["general"]["loss_function"]
        feature_extraction = parameters["general"]["feature_extraction"]
        print("Here is a list of current parameters.")
        print(
            "You can copy this code, change the parameters you want to set and run it to update the project config."
3231 ) 3232 print("--------------------------------------------------------") 3233 print("project.update_parameters(") 3234 print(" {") 3235 for group in ["general", "data", "training"]: 3236 print(f' "{group}": {{') 3237 for key in options.basic_parameters[group]: 3238 if key in parameters[group]: 3239 print( 3240 f' "{key}": {self._process_value(parameters[group][key])}, {self._get_comment(parameters[group].ca.items.get(key))}' 3241 ) 3242 print(" },") 3243 print(' "losses": {') 3244 print(f' "{loss_name}": {{') 3245 for key in options.basic_parameters["losses"][loss_name]: 3246 if key in parameters["losses"][loss_name]: 3247 print( 3248 f' "{key}": {self._process_value(parameters["losses"][loss_name][key])}, {self._get_comment(parameters["losses"][loss_name].ca.items.get(key))}' 3249 ) 3250 print(" },") 3251 print(" },") 3252 print(' "metrics": {') 3253 for metric in metric_names: 3254 print(f' "{metric}": {{') 3255 for key in parameters["metrics"][metric]: 3256 print( 3257 f' "{key}": {self._process_value(parameters["metrics"][metric][key])}, {self._get_comment(parameters["metrics"][metric].ca.items.get(key))}' 3258 ) 3259 print(" },") 3260 print(" },") 3261 print(' "model": {') 3262 for key in options.basic_parameters["model"][model_name]: 3263 if key in parameters["model"]: 3264 print( 3265 f' "{key}": {self._process_value(parameters["model"][key])}, {self._get_comment(parameters["model"].ca.items.get(key))}' 3266 ) 3267 3268 print(" },") 3269 print(' "features": {') 3270 for key in options.basic_parameters["features"][feature_extraction]: 3271 if key in parameters["features"]: 3272 print( 3273 f' "{key}": {self._process_value(parameters["features"][key])}, {self._get_comment(parameters["features"].ca.items.get(key))}' 3274 ) 3275 3276 print(" },") 3277 print(' "augmentations": {') 3278 for key in options.basic_parameters["augmentations"][feature_extraction]: 3279 if key in parameters["augmentations"]: 3280 print( 3281 f' "{key}": {self._process_value(parameters["augmentations"][key])}, {self._get_comment(parameters["augmentations"].ca.items.get(key))}' 3282 ) 3283 print(" },") 3284 print(" },") 3285 print(")") 3286 print("--------------------------------------------------------") 3287 print("\n") 3288 3289 def _create_record( 3290 self, 3291 episode_name: str, 3292 behaviors_dict: Dict, 3293 load_episode: str = None, 3294 parameters_update: Dict = None, 3295 task: TaskDispatcher = None, 3296 load_epoch: int = None, 3297 load_search: str = None, 3298 load_parameters: list = None, 3299 round_to_binary: list = None, 3300 load_strict: bool = True, 3301 n_seeds: int = 1, 3302 ) -> TaskDispatcher: 3303 """ 3304 Create a meta data episode record 3305 """ 3306 3307 if episode_name in self._episodes().data.index: 3308 return 3309 if type(n_seeds) is not int or n_seeds < 1: 3310 raise ValueError( 3311 f"The n_seeds parameter has to be an integer larger than 0; got {n_seeds}" 3312 ) 3313 if parameters_update is None: 3314 parameters_update = {} 3315 parameters = self._read_parameters() 3316 parameters = self._update(parameters, parameters_update) 3317 if load_search is not None: 3318 parameters = self._update_with_search( 3319 parameters, load_search, load_parameters, round_to_binary 3320 ) 3321 parameters = self._fill( 3322 parameters, 3323 episode_name, 3324 load_episode, 3325 load_epoch=load_epoch, 3326 only_load_model=True, 3327 load_strict=load_strict, 3328 continuing=True, 3329 ) 3330 self._save_episode(episode_name, parameters, behaviors_dict) 3331 return task 3332 3333 def 
_save_thresholds( 3334 self, 3335 episode_names: List, 3336 metric_name: str, 3337 parameters: Dict, 3338 thresholds: List, 3339 load_epochs: List, 3340 ): 3341 """ 3342 Save optimal decision thresholds in the meta records 3343 """ 3344 3345 metric_parameters = parameters["metrics"][metric_name] 3346 self._thresholds().save_thresholds( 3347 episode_names, load_epochs, metric_name, metric_parameters, thresholds 3348 ) 3349 3350 def _save_episode( 3351 self, 3352 episode_name: str, 3353 parameters: Dict, 3354 behaviors_dict: Dict, 3355 suppress_validation: bool = False, 3356 training_time: str = None, 3357 norm_stats: Dict = None, 3358 ) -> None: 3359 """ 3360 Save an episode in the meta files 3361 """ 3362 3363 try: 3364 split_info = self._split_info_from_filename( 3365 parameters["training"]["split_path"] 3366 ) 3367 parameters["training"]["partition_method"] = split_info["partition_method"] 3368 except: 3369 pass 3370 if norm_stats is not None: 3371 norm_stats = dict(norm_stats) 3372 parameters["training"]["stats"] = norm_stats 3373 self._episodes().save_episode( 3374 episode_name, 3375 parameters, 3376 behaviors_dict, 3377 suppress_validation=suppress_validation, 3378 training_time=training_time, 3379 ) 3380 3381 def _update_episode_results( 3382 self, 3383 episode_name: str, 3384 logs: Tuple, 3385 training_time: str = None, 3386 ) -> None: 3387 """ 3388 Save the results of a run in the meta files 3389 """ 3390 3391 self._episodes().update_episode_results(episode_name, logs, training_time) 3392 3393 def _save_prediction( 3394 self, 3395 episode_name: str, 3396 parameters: Dict, 3397 behaviors_dict: Dict, 3398 embedding: bool = False, 3399 inference_time: str = None, 3400 ) -> None: 3401 """ 3402 Save a prediction in the meta files 3403 """ 3404 3405 parameters = self._update( 3406 parameters, 3407 {"meta": {"embedding": embedding, "inference_time": inference_time}}, 3408 ) 3409 self._predictions().save_episode(episode_name, parameters, behaviors_dict) 3410 3411 def _save_search( 3412 self, 3413 search_name: str, 3414 parameters: Dict, 3415 n_trials: int, 3416 best_params: Dict, 3417 best_value: float, 3418 metric: str, 3419 search_space: Dict, 3420 ) -> None: 3421 """ 3422 Save a hyperparameter search in the meta files 3423 """ 3424 3425 self._searches().save_search( 3426 search_name, 3427 parameters, 3428 n_trials, 3429 best_params, 3430 best_value, 3431 metric, 3432 search_space, 3433 ) 3434 3435 def _save_stores(self, parameters: Dict) -> None: 3436 """ 3437 Save a pickled dataset in the meta files 3438 """ 3439 3440 name = os.path.basename(parameters["data"]["feature_save_path"]) 3441 self._saved_datasets().save_store(name, self._get_data_pars(parameters)) 3442 self.create_metadata_backup() 3443 3444 def _remove_stores(self, parameters: Dict, remove_active: bool = False) -> None: 3445 """ 3446 Remove the pre-computed features folder 3447 """ 3448 3449 name = os.path.basename(parameters["data"]["feature_save_path"]) 3450 if remove_active or name not in self._episodes().get_active_datasets(): 3451 self.remove_saved_features([name]) 3452 3453 def _check_episode_validity( 3454 self, episode_name: str, allow_doublecolon: bool = False, force: bool = False 3455 ) -> None: 3456 """ 3457 Check whether the episode name is valid 3458 """ 3459 3460 if episode_name.startswith("_"): 3461 raise ValueError( 3462 "Names starting with an underscore are reserved by dlc2action and cannot be used!" 3463 ) 3464 elif "." in episode_name: 3465 raise ValueError("Names containing '.' 
cannot be used!") 3466 if not allow_doublecolon and "::" in episode_name: 3467 raise ValueError( 3468 "Names containing '::' are reserved by dlc2action and cannot be used!" 3469 ) 3470 if force: 3471 self.remove_episode(episode_name) 3472 elif not self._episodes().check_name_validity(episode_name): 3473 raise ValueError( 3474 f"The {episode_name} name is already taken! Set force=True to overwrite." 3475 ) 3476 3477 def _check_search_validity(self, search_name: str, force: bool = False) -> None: 3478 """ 3479 Check whether the search name is valid 3480 """ 3481 3482 if search_name.startswith("_"): 3483 raise ValueError( 3484 "Names starting with an underscore are reserved by dlc2action and cannot be used!" 3485 ) 3486 elif "." in search_name: 3487 raise ValueError("Names containing '.' cannot be used!") 3488 if force: 3489 self.remove_search(search_name) 3490 elif not self._searches().check_name_validity(search_name): 3491 raise ValueError(f"The {search_name} name is already taken!") 3492 3493 def _check_prediction_validity( 3494 self, prediction_name: str, force: bool = False 3495 ) -> None: 3496 """ 3497 Check whether the prediction name is valid 3498 """ 3499 3500 if prediction_name.startswith("_"): 3501 raise ValueError( 3502 "Names starting with an underscore are reserved by dlc2action and cannot be used!" 3503 ) 3504 elif "." in prediction_name: 3505 raise ValueError("Names containing '.' cannot be used!") 3506 if force: 3507 self.remove_prediction(prediction_name) 3508 elif not self._predictions().check_name_validity(prediction_name): 3509 raise ValueError(f"The {prediction_name} name is already taken!") 3510 3511 def _training_time(self, episode_name: str) -> int: 3512 """ 3513 Get the training time of an episode in seconds 3514 """ 3515 3516 return self._episode(episode_name).training_time() 3517 3518 def _mask_path(self) -> str: 3519 """ 3520 Get the path to the masks folder 3521 """ 3522 3523 return os.path.join(self.project_path, "results", "masks") 3524 3525 def _thresholds_path(self) -> str: 3526 """ 3527 Get the path to the thresholds meta file 3528 """ 3529 3530 return os.path.join(self.project_path, "meta", "thresholds.pickle") 3531 3532 def _episodes_path(self) -> str: 3533 """ 3534 Get the path to the episodes meta file 3535 """ 3536 3537 return os.path.join(self.project_path, "meta", "episodes.pickle") 3538 3539 def _saved_datasets_path(self) -> str: 3540 """ 3541 Get the path to the datasets meta file 3542 """ 3543 3544 return os.path.join(self.project_path, "meta", "saved_datasets.pickle") 3545 3546 def _predictions_path(self) -> str: 3547 """ 3548 Get the path to the predictions meta file 3549 """ 3550 3551 return os.path.join(self.project_path, "meta", "predictions.pickle") 3552 3553 def _dataset_store_path(self, name: str) -> str: 3554 """ 3555 Get the path to a specific pickled dataset 3556 """ 3557 3558 return os.path.join(self.project_path, "saved_datasets", f"{name}.pickle") 3559 3560 def _searches_path(self) -> str: 3561 """ 3562 Get the path to the hyperparameter search meta file 3563 """ 3564 3565 return os.path.join(self.project_path, "meta", "searches.pickle") 3566 3567 def _search_path(self, name: str) -> str: 3568 """ 3569 Get the default path to the graph folder for a specific hyperparameter search 3570 """ 3571 3572 return os.path.join(self.project_path, "results", "searches", name) 3573 3574 def _version_path(self) -> str: 3575 """ 3576 Get the path to the version file 3577 """ 3578 3579 return os.path.join(self.project_path, "meta", "version.txt") 
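
    # The two helpers below translate between split parameters and the default
    # split file name. For illustration (hypothetical values): a split with
    # partition_method="random", val_frac=0.2, test_frac=0.1, len_segment=128,
    # overlap=50 and only_load_annotated=True maps to
    # "random_val20.0%_test10.0%_len128_overlap50.txt", and
    # _split_info_from_filename() recovers the same parameters from that name.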
    def _default_split_file(self, split_info: Dict) -> Optional[str]:
        """
        Generate the default path to a split file from split parameters
        """

        if split_info["partition_method"].startswith("time"):
            return None
        val_frac = split_info["val_frac"]
        test_frac = split_info["test_frac"]
        split_name = f'{split_info["partition_method"]}_val{val_frac * 100}%_test{test_frac * 100}%_len{split_info["len_segment"]}_overlap{split_info["overlap"]}'
        if not split_info["only_load_annotated"]:
            split_name += "_all"
        split_name += ".txt"
        return os.path.join(self.project_path, "results", "splits", split_name)

    def _split_info_from_filename(self, split_name: str) -> Dict:
        """
        Recover split parameters from the default path to a split file
        """

        if split_name is None:
            return {}
        try:
            name = os.path.basename(split_name)[:-4]
            split = name.split("_")
            if len(split) == 6:
                only_load_annotated = False
            else:
                only_load_annotated = True
            len_segment = int(split[3][3:])
            overlap = int(split[4][7:])
            method, val, test = split[:3]
            val = float(val[3:-1]) / 100
            test = float(test[4:-1]) / 100
            return {
                "partition_method": method,
                "val_frac": val,
                "test_frac": test,
                "only_load_annotated": only_load_annotated,
                "len_segment": len_segment,
                "overlap": overlap,
            }
        except:
            return {"partition_method": "file"}

    def _fill(
        self,
        parameters: Dict,
        episode_name: str,
        load_experiment: str = None,
        load_epoch: int = None,
        load_strict: bool = True,
        only_load_model: bool = False,
        continuing: bool = False,
        enforce_split_parameters: bool = False,
    ) -> Dict:
        """
        Update the parameters from the config files with project-specific information

        Fill in the constant file path parameters and generate a unique log file and a model folder.
        Fill in the split file if the same split has been run before in the project and change the partition
        method to `"file"`.
        Fill in the saved data path if a dataset with the same data parameters already exists in the project.
        If `load_experiment` is not `None`, fill in the checkpoint path as well.
        The `only_load_model` training parameter is defined by the corresponding argument.
        If `continuing` is `True`, new files are not created and all information is loaded from `load_experiment`.
        The `enforce_split_parameters` argument is used to resolve conflicts between the split file path and the
        split parameters when they arise.
        """

        pars = deepcopy(parameters)
        if episode_name == "_":
            self.remove_episode("_")
        log = os.path.join(self.project_path, "results", "logs", f"{episode_name}.txt")
        model_save_path = os.path.join(
            self.project_path, "results", "model", episode_name
        )
        if not continuing and (os.path.exists(log) or os.path.exists(model_save_path)):
            raise ValueError(
                f"The {episode_name} episode name is already in use! Set force=True to overwrite."
            )
        keys = ["val_frac", "test_frac", "partition_method"]
        if "len_segment" not in pars["general"] and "len_segment" in pars["data"]:
            pars["general"]["len_segment"] = pars["data"]["len_segment"]
        if "overlap" not in pars["general"] and "overlap" in pars["data"]:
            pars["general"]["overlap"] = pars["data"]["overlap"]
        if "len_segment" in pars["data"]:
            pars["data"].pop("len_segment")
        if "overlap" in pars["data"]:
            pars["data"].pop("overlap")
        split_info = {k: pars["training"][k] for k in keys}
        split_info["only_load_annotated"] = pars["general"]["only_load_annotated"]
        split_info["len_segment"] = pars["general"]["len_segment"]
        split_info["overlap"] = pars["general"]["overlap"]
        pars["training"]["log_file"] = log
        if not os.path.exists(model_save_path):
            os.mkdir(model_save_path)
        pars["training"]["model_save_path"] = model_save_path
        if load_experiment is not None:
            if load_experiment not in self._episodes().data.index:
                raise ValueError(f"The {load_experiment} episode does not exist!")
            old_episode = self._episode(load_experiment)
            old_file = old_episode.split_file()
            old_info = self._split_info_from_filename(old_file)
            if len(old_info) == 0:
                old_info = old_episode.split_info()
            if enforce_split_parameters:
                if split_info["partition_method"] != "file":
                    pars["training"]["split_path"] = self._default_split_file(
                        split_info
                    )
            else:
                equal = True
                if old_info["partition_method"] != split_info["partition_method"]:
                    equal = False
                if old_info["partition_method"] != "file":
                    if (
                        old_info["val_frac"] != split_info["val_frac"]
                        or old_info["test_frac"] != split_info["test_frac"]
                    ):
                        equal = False
                if not continuing and not equal:
                    warnings.warn(
                        f"The partitioning parameters in the loaded experiment ({old_info}) "
                        f"are not equal to the current partitioning parameters ({split_info}). "
                        f"The current parameters are replaced."
                    )
                pars["training"]["split_path"] = old_file
            pars["training"]["checkpoint_path"] = old_episode.model_file(load_epoch)
            pars["training"]["load_strict"] = load_strict
        else:
            pars["training"]["checkpoint_path"] = None
        if pars["training"]["partition_method"] == "file":
            if (
                "split_path" not in pars["training"]
                or pars["training"]["split_path"] is None
            ):
                raise ValueError(
                    "The partition_method parameter is set to file but the "
                    "split_path parameter is not set!"
                )
            elif not os.path.exists(pars["training"]["split_path"]):
                raise ValueError(
                    f'The {pars["training"]["split_path"]} split file does not exist'
                )
        else:
            pars["training"]["split_path"] = self._default_split_file(split_info)
        pars["training"]["only_load_model"] = only_load_model
        pars["data"]["saved_data_path"] = None
        pars["data"]["feature_save_path"] = None
        pars_data_copy = self._get_data_pars(pars)
        saved_data_name = self._saved_datasets().find_name(pars_data_copy)
        if saved_data_name is not None:
            pars["data"]["saved_data_path"] = self._dataset_store_path(saved_data_name)
            pars["data"]["feature_save_path"] = self._dataset_store_path(
                saved_data_name
            ).split(".")[0]
        else:
            dataset_path = self._dataset_store_path(episode_name)
            if os.path.exists(dataset_path):
                name, ext = dataset_path.split(".")
                i = 0
                while os.path.exists(f"{name}_{i}.{ext}"):
                    i += 1
                dataset_path = f"{name}_{i}.{ext}"
            pars["data"]["saved_data_path"] = dataset_path
            pars["data"]["feature_save_path"] = dataset_path.split(".")[0]
        split_split = pars["training"]["partition_method"].split(":")
        random = True
        for partition_method in options.partition_methods["fixed"]:
            method_split = partition_method.split(":")
            if len(split_split) != len(method_split):
                continue
            equal = True
            for x, y in zip(split_split, method_split):
                if y.startswith("{"):
                    continue
                if x != y:
                    equal = False
                    break
            if equal:
                random = False
                break
        if random and os.path.exists(pars["training"]["split_path"]):
            pars["training"]["partition_method"] = "file"
        pars["general"]["save_dataset"] = True
        return pars

    def _get_data_pars(self, pars: Dict) -> Dict:
        """
        Get a complete description of the data from a general parameters dictionary
        """

        pars_data_copy = deepcopy(pars["data"])
        for par in [
            "only_load_annotated",
            "exclusive",
            "feature_extraction",
            "ignored_clips",
            "len_segment",
            "overlap",
        ]:
            pars_data_copy[par] = pars["general"].get(par, None)
        pars_data_copy.update(pars["features"])
        return pars_data_copy

    def count_classes(
        self,
        load_episode: str = None,
        parameters_update: Dict = None,
        remove_saved_features: bool = False,
        bouts: bool = True,
    ) -> Dict:
        """
        Get a dictionary of class counts in different modes

        Parameters
        ----------
        load_episode : str, optional
            the episode settings to load
        parameters_update : dict, optional
            a dictionary of parameter updates (only for "data" and "general" categories)
        remove_saved_features : bool, default False
            if `True`, the dataset that is used for computation is then deleted
        bouts : bool, default True
            if `True`, segment (bout) counts are returned instead of frame counts

        Returns
        -------
        class_counts : dict
            a nested dictionary where first-level keys are "train", "val" and "test", second-level keys are
            class names and values are class counts (in frames or bouts, depending on `bouts`)
        """

        if load_episode is None:
            task, parameters = self._make_task_training(
                episode_name="_", parameters_update=parameters_update, throwaway=True
            )
        else:
            task, parameters, _ = self._make_task_prediction(
                "_",
                load_episode=load_episode,
                parameters_update=parameters_update,
            )
        class_counts = task.count_classes(bouts=bouts)
task.count_classes(bouts=bouts) 3827 behaviors = task.behaviors_dict() 3828 class_counts = { 3829 kk: {behaviors.get(k, "unknown"): v for k, v in vv.items()} 3830 for kk, vv in class_counts.items() 3831 } 3832 if remove_saved_features: 3833 self._remove_stores(parameters) 3834 return class_counts 3835 3836 def plot_class_distribution( 3837 self, 3838 parameters_update: Dict = None, 3839 frame_cutoff: int = 1, 3840 bout_cutoff: int = 1, 3841 print_full: bool = False, 3842 remove_saved_features: bool = False, 3843 ) -> None: 3844 """ 3845 Make a class distribution plot 3846 3847 You can either specify the parameters, choose an existing dataset or do neither (in that case a dataset 3848 is created or laoded for the computation with the default parameters). 3849 3850 Parameters 3851 ---------- 3852 parameters_update : dict, optional 3853 a dictionary of parameter updates (only for "data" and "general" categories) 3854 remove_saved_features : bool, default False 3855 if `True`, the dataset that is used for computation is then deleted 3856 """ 3857 3858 task, parameters = self._make_task_training( 3859 episode_name="_", parameters_update=parameters_update, throwaway=True 3860 ) 3861 cutoff = {True: bout_cutoff, False: frame_cutoff} 3862 for bouts in [True, False]: 3863 class_counts = task.count_classes(bouts=bouts) 3864 if print_full: 3865 print("Bouts:" if bouts else "Frames:") 3866 for k, v in class_counts.items(): 3867 if sum(v.values()) != 0: 3868 print(f" {k}:") 3869 values, keys = zip( 3870 *[ 3871 x 3872 for x in sorted(zip(v.values(), v.keys()), reverse=True) 3873 if x[-1] != -100 3874 ] 3875 ) 3876 for kk, vv in zip(keys, values): 3877 print(f" {task.behaviors_dict()[kk]}: {vv}") 3878 class_counts = { 3879 kk: {k: v for k, v in vv.items() if v >= cutoff[bouts]} 3880 for kk, vv in class_counts.items() 3881 } 3882 for key, d in class_counts.items(): 3883 if sum(d.values()) != 0: 3884 values, keys = zip( 3885 *[x for x in sorted(zip(d.values(), d.keys())) if x[-1] != -100] 3886 ) 3887 keys = [task.behaviors_dict()[x] for x in keys] 3888 plt.bar(keys, values) 3889 plt.title(key) 3890 plt.xticks(rotation=45, ha="right") 3891 if bouts: 3892 plt.ylabel("bouts") 3893 else: 3894 plt.ylabel("frames") 3895 plt.tight_layout() 3896 plt.show() 3897 if remove_saved_features: 3898 self._remove_stores(parameters) 3899 3900 def _generate_mask( 3901 self, 3902 mask_name: str, 3903 perc_annotated: float = 0.1, 3904 parameters_update: Dict = None, 3905 remove_saved_features: bool = False, 3906 ) -> None: 3907 """ 3908 Generate a real_lens for active learning simulation 3909 3910 Parameters 3911 ---------- 3912 mask_name : str 3913 the name of the real_lens 3914 """ 3915 3916 print(f"GENERATING {mask_name}") 3917 task, parameters = self._make_task_training( 3918 f"_{mask_name}", parameters_update=parameters_update, throwaway=True 3919 ) 3920 val_intervals, val_ids = task.dataset("val").get_intervals() # 1 3921 unannotated_intervals = task.dataset("train").get_unannotated_intervals() # 2 3922 unannotated_intervals = task.dataset("val").get_unannotated_intervals( 3923 first_intervals=unannotated_intervals 3924 ) 3925 ids = task.dataset("train").get_ids() 3926 mask = {video_id: {} for video_id in ids} 3927 total_all = 0 3928 total_masked = 0 3929 for video_id, clip_ids in ids.items(): 3930 for clip_id in clip_ids: 3931 frames = np.ones(task.dataset("train").get_len(video_id, clip_id)) 3932 if clip_id in val_intervals[video_id]: 3933 for start, end in val_intervals[video_id][clip_id]: 3934 frames[start:end] = 0 
3935 if clip_id in unannotated_intervals[video_id]: 3936 for start, end in unannotated_intervals[video_id][clip_id]: 3937 frames[start:end] = 0 3938 annotated = np.where(frames)[0] 3939 total_all += len(annotated) 3940 masked = annotated[-int(len(annotated) * (1 - perc_annotated)) :] 3941 total_masked += len(masked) 3942 mask[video_id][clip_id] = self._get_intervals(masked) 3943 file = { 3944 "masked": mask, 3945 "val_intervals": val_intervals, 3946 "val_ids": val_ids, 3947 "unannotated": unannotated_intervals, 3948 } 3949 self._save_mask(file, mask_name) 3950 if remove_saved_features: 3951 self._remove_stores(parameters) 3952 print("\n") 3953 # print(f'Unmasked: {sum([(vv == 0).sum() for v in mask.values() for vv in v.values()])} frames') 3954 3955 def _get_intervals(self, frame_indices: np.ndarray): 3956 """ 3957 Get a list of intervals from a list of frame indices 3958 3959 Example: `[0, 1, 2, 5, 6, 8] -> [[0, 3], [5, 7], [8, 9]]`. 3960 3961 Parameters 3962 ---------- 3963 frame_indices : np.ndarray 3964 a list of frame indices 3965 3966 Returns 3967 ------- 3968 intervals : list 3969 a list of interval boundaries 3970 """ 3971 3972 masked_intervals = [] 3973 if len(frame_indices) > 0: 3974 breaks = np.where(np.diff(frame_indices) != 1)[0] 3975 start = frame_indices[0] 3976 for k in breaks: 3977 masked_intervals.append([start, frame_indices[k] + 1]) 3978 start = frame_indices[k + 1] 3979 masked_intervals.append([start, frame_indices[-1] + 1]) 3980 return masked_intervals 3981 3982 def _update_mask_with_uncertainty( 3983 self, 3984 mask_name: str, 3985 episode_name: Union[str, None], 3986 classes: List, 3987 load_epoch: int = None, 3988 n_frames: int = 10000, 3989 method: str = "least_confidence", 3990 min_length: int = 30, 3991 augment_n: int = 0, 3992 parameters_update: Dict = None, 3993 ): 3994 """ 3995 Update the mask with frame-wise uncertainty scores for active learning 3996 3997 Parameters 3998 ---------- 3999 mask_name : str 4000 the name of the mask 4001 episode_name : str 4002 the name of the episode to load 4003 classes : list 4004 a list of class names or indices; their uncertainty scores will be computed separately and stacked load_epoch : int, optional the epoch to load (by default the last) 4005 n_frames : int, default 10000 4006 the number of frames to "annotate" 4007 method : {"least_confidence", "entropy"} 4008 the method used to calculate the scores from the probability predictions (`"least_confidence"`: `1 - p_i` if 4009 `p_i > 0.5` or `p_i` if `p_i < 0.5`; `"entropy"`: `- p_i * log(p_i) - (1 - p_i) * log(1 - p_i)`) 4010 min_length : int 4011 the minimum length (in frames) of the annotated intervals 4012 augment_n : int, default 0 4013 the number of augmentations to average over 4014 parameters_update : dict, optional 4015 the dictionary used to update the parameters from the config 4016 4017 Returns 4018 ------- 4019 score_dicts : dict 4020 a nested dictionary where first level keys are video ids, second level keys are clip ids and values 4021 are score tensors 4022 """ 4023 4024 print(f"UPDATING {mask_name}") 4025 task, parameters, _ = self._make_task_prediction( 4026 prediction_name=mask_name, 4027 load_episode=episode_name, 4028 parameters_update=parameters_update, 4029 load_epoch=load_epoch, 4030 mode="train", 4031 ) 4032 score_tensors = task.generate_uncertainty_score(classes, augment_n, method) 4033 self._update_mask(task, mask_name, score_tensors, n_frames, min_length) 4034 print("\n") 4035 4036 def _update_mask_with_BALD( 4037 self, 4038 mask_name: str, 4039 episode_name: str, 4040 classes: List, 4041 load_epoch:
int = None, 4042 augment_n: int = 0, 4043 n_frames: int = 10000, 4044 num_models: int = 10, 4045 kernel_size: int = 11, 4046 min_length: int = 30, 4047 parameters_update: Dict = None, 4048 ): 4049 """ 4050 Update the mask with frame-wise Bayesian Active Learning by Disagreement scores for active learning 4051 4052 Parameters 4053 ---------- 4054 mask_name : str 4055 the name of the mask 4056 episode_name : str 4057 the name of the episode to load 4058 classes : list 4059 a list of class names or indices; their uncertainty scores will be computed separately and stacked load_epoch : int, optional the epoch to load (by default the last) 4060 augment_n : int, default 0 4061 the number of augmentations to average over 4062 n_frames : int, default 10000 4063 the number of frames to "annotate" 4064 num_models : int, default 10 4065 the number of dropout masks to apply 4066 kernel_size : int, default 11 4067 the size of the smoothing gaussian kernel 4068 min_length : int 4069 the minimum length (in frames) of the annotated intervals 4070 parameters_update : dict, optional 4071 the dictionary used to update the parameters from the config 4072 4073 Returns 4074 ------- 4075 score_dicts : dict 4076 a nested dictionary where first level keys are video ids, second level keys are clip ids and values 4077 are score tensors 4078 """ 4079 4080 print(f"UPDATING {mask_name}") 4081 task, parameters, mode = self._make_task_prediction( 4082 mask_name, 4083 load_episode=episode_name, 4084 parameters_update=parameters_update, 4085 load_epoch=load_epoch, 4086 ) 4087 score_tensors = task.generate_bald_score( 4088 classes, augment_n, num_models, kernel_size 4089 ) 4090 self._update_mask(task, mask_name, score_tensors, n_frames, min_length) 4091 print("\n") 4092 4093 def _suggest_intervals( 4094 self, 4095 dataset: BehaviorDataset, 4096 score_tensors: Dict, 4097 n_frames: int, 4098 min_length: int, 4099 ) -> Dict: 4100 """ 4101 Suggest intervals with highest score of total length `n_frames` 4102 4103 Parameters 4104 ---------- 4105 dataset : BehaviorDataset 4106 the dataset 4107 score_tensors : dict 4108 a dictionary where keys are clip ids and values are framewise score tensors 4109 n_frames : int 4110 the number of frames to "annotate" 4111 min_length : int the minimum length (in frames) of the suggested intervals 4112 4113 Returns 4114 ------- 4115 active_learning_intervals : Dict 4116 active learning dictionary with suggested intervals 4117 """ 4118 4119 video_intervals, _ = dataset.get_intervals() 4120 taken = { 4121 video_id: defaultdict(lambda: {}) for video_id in video_intervals.keys() 4122 } 4123 annotated = dataset.get_annotated_intervals() 4124 for video_id in video_intervals: 4125 for clip_id in video_intervals[video_id]: 4126 taken[video_id][clip_id] = torch.zeros( 4127 dataset.get_len(video_id, clip_id) 4128 ) 4129 if video_id in annotated and clip_id in annotated[video_id]: 4130 for start, end in annotated[video_id][clip_id]: 4131 score_tensors[video_id][clip_id][:, start:end] = -10 4132 taken[video_id][clip_id][int(start) : int(end)] = 1 4133 n_frames = ( 4134 sum([(vv == 1).sum() for v in taken.values() for vv in v.values()]) 4135 + n_frames 4136 ) 4137 factor = 1 4138 threshold_start = float( 4139 torch.mean( 4140 torch.tensor( 4141 [ 4142 torch.mean( 4143 torch.tensor([torch.mean(y[y > 0]) for y in x.values()]) 4144 ) 4145 for x in score_tensors.values() 4146 ] 4147 ) 4148 ) 4149 ) 4150 while ( 4151 sum([(vv == 1).sum() for v in taken.values() for vv in v.values()]) 4152 < n_frames 4153 ): 4154 threshold = threshold_start * factor 4155 intervals = [] 4156 interval_scores = [] 4157 key1 =
list(score_tensors.keys())[0] 4158 key2 = list(score_tensors[key1].keys())[0] 4159 num_scores = score_tensors[key1][key2].shape[0] 4160 for i in range(num_scores): 4161 v_dict = dataset.find_valleys( 4162 predicted=score_tensors, 4163 threshold=threshold, 4164 min_frames=min_length, 4165 main_class=i, 4166 low=False, 4167 ) 4168 for v_id, interval_list in v_dict.items(): 4169 intervals += [x + [v_id] for x in interval_list] 4170 interval_scores += [ 4171 float(torch.mean(score_tensors[v_id][clip_id][i, start:end])) 4172 for start, end, clip_id in interval_list 4173 ] 4174 intervals = np.array(intervals)[np.argsort(interval_scores)[::-1]] 4175 i = 0 4176 while sum( 4177 [(vv == 1).sum() for v in taken.values() for vv in v.values()] 4178 ) < n_frames and i < len(intervals): 4179 start, end, clip_id, video_id = intervals[i] 4180 i += 1 4181 taken[video_id][clip_id][int(start) : int(end)] = 1 4182 factor *= 0.9 4183 if factor < 0.05: 4184 warnings.warn(f"Could not find enough frames!") 4185 break 4186 active_learning_intervals = {video_id: [] for video_id in video_intervals} 4187 for video_id in taken: 4188 for clip_id in taken[video_id]: 4189 if video_id in annotated and clip_id in annotated[video_id]: 4190 for start, end in annotated[video_id][clip_id]: 4191 taken[video_id][clip_id][int(start) : int(end)] = 0 4192 if (taken[video_id][clip_id] == 1).sum() == 0: 4193 continue 4194 indices = np.where(taken[video_id][clip_id].numpy())[0] 4195 boundaries = self._get_intervals(indices) 4196 active_learning_intervals[video_id] += [ 4197 [start, end, clip_id] for start, end in boundaries 4198 ] 4199 return active_learning_intervals 4200 4201 def _update_mask( 4202 self, 4203 task: TaskDispatcher, 4204 mask_name: str, 4205 score_tensors: Dict, 4206 n_frames: int, 4207 min_length: int, 4208 ) -> None: 4209 """ 4210 Update the mask with intervals with the highest score of total length `n_frames` 4211 4212 Parameters 4213 ---------- 4214 mask_name : str 4215 the name of the mask 4216 score_tensors : dict 4217 a dictionary where keys are clip ids and values are framewise score tensors 4218 n_frames : int 4219 the number of frames to "annotate" 4220 min_length : int 4221 the minimum length of the annotated intervals 4222 """ 4223 4224 mask = self._load_mask(mask_name) 4225 video_intervals, _ = task.dataset("train").get_intervals() 4226 masked = { 4227 video_id: defaultdict(lambda: {}) for video_id in video_intervals.keys() 4228 } 4229 total_masked = 0 4230 total_all = 0 4231 for video_id in video_intervals: 4232 for clip_id in video_intervals[video_id]: 4233 masked[video_id][clip_id] = torch.zeros( 4234 task.dataset("train").get_len(video_id, clip_id) 4235 ) 4236 if ( 4237 video_id in mask["unannotated"] 4238 and clip_id in mask["unannotated"][video_id] 4239 ): 4240 for start, end in mask["unannotated"][video_id][clip_id]: 4241 score_tensors[video_id][clip_id][:, start:end] = -10 4242 masked[video_id][clip_id][int(start) : int(end)] = 1 4243 if ( 4244 video_id in mask["val_intervals"] 4245 and clip_id in mask["val_intervals"][video_id] 4246 ): 4247 for start, end in mask["val_intervals"][video_id][clip_id]: 4248 score_tensors[video_id][clip_id][:, start:end] = -10 4249 masked[video_id][clip_id][int(start) : int(end)] = 1 4250 total_all += torch.sum(masked[video_id][clip_id] == 0) 4251 if video_id in mask["masked"] and clip_id in mask["masked"][video_id]: 4252 # print(f'{mask["masked"][video_id][clip_id]=}') 4253 for start, end in mask["masked"][video_id][clip_id]: 4254
masked[video_id][clip_id][int(start) : int(end)] = 1 4255 total_masked += end - start 4256 old_n_frames = sum( 4257 [(vv == 0).sum() for v in masked.values() for vv in v.values()] 4258 ) 4259 n_frames = old_n_frames + n_frames 4260 factor = 1 4261 while ( 4262 sum([(vv == 0).sum() for v in masked.values() for vv in v.values()]) 4263 < n_frames 4264 ): 4265 threshold = float( 4266 torch.mean( 4267 torch.tensor( 4268 [ 4269 torch.mean( 4270 torch.tensor([torch.mean(y[y > 0]) for y in x.values()]) 4271 ) 4272 for x in score_tensors.values() 4273 ] 4274 ) 4275 ) 4276 ) 4277 threshold = threshold * factor 4278 intervals = [] 4279 interval_scores = [] 4280 key1 = list(score_tensors.keys())[0] 4281 key2 = list(score_tensors[key1].keys())[0] 4282 num_scores = score_tensors[key1][key2].shape[0] 4283 for i in range(num_scores): 4284 v_dict = task.dataset("train").find_valleys( 4285 predicted=score_tensors, 4286 threshold=threshold, 4287 min_frames=min_length, 4288 main_class=i, 4289 low=False, 4290 ) 4291 for v_id, interval_list in v_dict.items(): 4292 intervals += [x + [v_id] for x in interval_list] 4293 interval_scores += [ 4294 float(torch.mean(score_tensors[v_id][clip_id][i, start:end])) 4295 for start, end, clip_id in interval_list 4296 ] 4297 intervals = np.array(intervals)[np.argsort(interval_scores)[::-1]] 4298 i = 0 4299 while sum( 4300 [(vv == 0).sum() for v in masked.values() for vv in v.values()] 4301 ) < n_frames and i < len(intervals): 4302 start, end, clip_id, video_id = intervals[i] 4303 i += 1 4304 masked[video_id][clip_id][int(start) : int(end)] = 0 4305 factor *= 0.9 4306 if factor < 0.05: 4307 warnings.warn(f"Could not find enough frames!") 4308 break 4309 mask["masked"] = {video_id: {} for video_id in video_intervals} 4310 total_masked_new = 0 4311 for video_id in masked: 4312 for clip_id in masked[video_id]: 4313 if ( 4314 video_id in mask["unannotated"] 4315 and clip_id in mask["unannotated"][video_id] 4316 ): 4317 for start, end in mask["unannotated"][video_id][clip_id]: 4318 masked[video_id][clip_id][int(start) : int(end)] = 0 4319 if ( 4320 video_id in mask["val_intervals"] 4321 and clip_id in mask["val_intervals"][video_id] 4322 ): 4323 for start, end in mask["val_intervals"][video_id][clip_id]: 4324 masked[video_id][clip_id][int(start) : int(end)] = 0 4325 indices = np.where(masked[video_id][clip_id].numpy())[0] 4326 mask["masked"][video_id][clip_id] = self._get_intervals(indices) 4327 for video_id in mask["masked"]: 4328 for clip_id in mask["masked"][video_id]: 4329 for start, end in mask["masked"][video_id][clip_id]: 4330 total_masked_new += end - start 4331 self._save_mask(mask, mask_name) 4332 with open( 4333 os.path.join(self.project_path, "results", f"{mask_name}.txt"), "a" 4334 ) as f: 4335 f.write(f"from {total_masked} to {total_masked_new} / {total_all}" + "\n") 4336 print(f"Unmasked from {total_masked} to {total_masked_new} / {total_all}") 4337 4338 def plot_confusion_matrix( 4339 self, 4340 episode_name: str, 4341 load_epoch: int = None, 4342 parameters_update: Dict = None, 4343 type: str = "recall", 4344 mode: str = "val", 4345 remove_saved_features: bool = False, 4346 ) -> Tuple[ndarray, Iterable]: 4347 """ 4348 Make a confusion matrix plot and return the data 4349 4350 If the annotation is non-exclusive, only false positive labels are considered. 
4351 4352 Parameters 4353 ---------- 4354 episode_name : str 4355 the name of the episode to load 4356 load_epoch : int, optional 4357 the index of the epoch to load (by default the last one is loaded) 4358 parameters_update : dict, optional 4359 a dictionary of parameter updates (only for "data" and "general" categories) 4360 mode : {'val', 'all', 'test', 'train'} 4361 the subset of the data to make the prediction for (forced to 'all' if data_path is not None) 4362 type : {"recall", "precision"} 4363 for datasets with non-exclusive annotation, if `type` is `"recall"`, only false positives are taken 4364 into account, and if `type` is `"precision"`, only false negatives 4365 remove_saved_features : bool, default False 4366 if `True`, the dataset that is used for computation is then deleted 4367 4368 Returns 4369 ------- 4370 confusion_matrix : np.ndarray 4371 a confusion matrix of shape `(#classes, #classes)` where `A[i, j] = F_ij/N_i`, `F_ij` is the number of 4372 frames that have the i-th label in the ground truth and a false positive j-th label in the prediction, 4373 `N_i` is the number of frames that have the i-th label in the ground truth 4374 classes : list 4375 a list of labels 4376 """ 4377 4378 task, parameters, mode = self._make_task_prediction( 4379 "_", 4380 load_episode=episode_name, 4381 load_epoch=load_epoch, 4382 parameters_update=parameters_update, 4383 mode=mode, 4384 ) 4385 dataset = task.dataset(mode) 4386 prediction = task.predict(dataset, raw_output=True) 4387 confusion_matrix, classes, type = dataset.get_confusion_matrix(prediction, type) 4388 if remove_saved_features: 4389 self._remove_stores(parameters) 4390 fig, ax = plt.subplots(figsize=(len(classes), len(classes))) 4391 ax.imshow(confusion_matrix) 4392 # Show all ticks and label them with the respective list entries 4393 ax.set_xticks(np.arange(len(classes))) 4394 ax.set_xticklabels(classes) 4395 ax.set_yticks(np.arange(len(classes))) 4396 ax.set_yticklabels(classes) 4397 # Rotate the tick labels and set their alignment. 4398 plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor") 4399 # Loop over data dimensions and create text annotations. 
4400 for i in range(len(classes)): 4401 for j in range(len(classes)): 4402 ax.text( 4403 j, 4404 i, 4405 np.round(confusion_matrix[i, j], 2), 4406 ha="center", 4407 va="center", 4408 color="w", 4409 ) 4410 if type is not None: 4411 ax.set_title(f"{type} {episode_name}") 4412 else: 4413 ax.set_title(episode_name) 4414 fig.tight_layout() 4415 plt.show() 4416 return confusion_matrix, classes 4417 4418 def plot_predictions( 4419 self, 4420 episode_name: str, 4421 load_epoch: int = None, 4422 parameters_update: Dict = None, 4423 add_legend: bool = True, 4424 ground_truth: bool = True, 4425 colormap: str = "viridis", 4426 hide_axes: bool = False, 4427 min_classes: int = 1, 4428 width: float = 10, 4429 whole_video: bool = False, 4430 transparent: bool = False, 4431 drop_classes: Set = None, 4432 search_classes: Set = None, 4433 num_plots: int = 1, 4434 remove_saved_features: bool = False, 4435 smooth_interval_prediction: int = 0, 4436 data_path: str = None, 4437 file_paths: Set = None, 4438 mode: str = "val", 4439 behavior_name: str = None, 4440 ) -> None: 4441 """ 4442 Visualize random predictions 4443 4444 Parameters 4445 ---------- 4446 episode_name : str 4447 the name of the episode to load 4448 load_epoch : int, optional 4449 the epoch to load (by default last) 4450 parameters_update : dict, optional 4451 parameter update dictionary 4452 add_legend : bool, default True 4453 if True, legend will be added to the plot 4454 ground_truth : bool, default True 4455 if True, ground truth will be added to the plot 4456 colormap : str, default 'viridis' 4457 the `matplotlib` colormap to use 4458 hide_axes : bool, default False 4459 if `True`, the axes will be hidden on the plot 4460 min_classes : int, default 1 4461 the minimum number of classes in a displayed interval 4462 width : float, default 10 4463 the width of the plot 4464 whole_video : bool, default False 4465 if `True`, whole videos are plotted instead of segments 4466 transparent : bool, default False 4467 if `True`, the background on the plot is transparent 4468 drop_classes : set, optional 4469 a set of class names to not be displayed 4470 search_classes : set, optional 4471 if given, only intervals where at least one of the classes is in ground truth will be shown 4472 num_plots : int, default 1 4473 the number of plots to make 4474 remove_saved_features : bool, default False 4475 if `True`, the dataset will be deleted after computation 4476 smooth_interval_prediction : int, default 0 4477 if >0, predictions shorter than this number of frames are removed (filled with prediction for the previous frame) 4478 data_path : str, optional 4479 the data path to run the prediction for 4480 mode : {'all', 'test', 'val', 'train'} 4481 the subset of the data to make the prediction for (forced to 'all' if data_path is not None) 4482 file_paths : set, optional 4483 a set of string file paths (data with all prefixes + feature files, in any order) to run the prediction 4484 for 4485 behavior_name : str, optional 4486 for non-exclusive classification datasets, choose which behavior to visualize (by default first in list) 4487 """ 4488 4489 other_path = os.path.join(self.project_path, "results", "other") 4490 task, parameters, mode = self._make_task_prediction( 4491 "_", 4492 load_episode=episode_name, 4493 parameters_update=parameters_update, 4494 load_epoch=load_epoch, 4495 data_path=data_path, 4496 file_paths=file_paths, 4497 mode=mode, 4498 ) 4499 if not os.path.exists(other_path): 4500 os.mkdir(other_path) 4501 for i in range(num_plots): 4502
task.visualize_results( 4503 save_path=os.path.join( 4504 other_path, f"{episode_name}_prediction_{i}.jpg" 4505 ), 4506 add_legend=add_legend, 4507 ground_truth=ground_truth, 4508 colormap=colormap, 4509 hide_axes=hide_axes, 4510 min_classes=min_classes, 4511 whole_video=whole_video, 4512 transparent=transparent, 4513 dataset=mode, 4514 drop_classes=drop_classes, 4515 search_classes=search_classes, 4516 width=width, 4517 smooth_interval_prediction=smooth_interval_prediction, 4518 behavior_name=behavior_name, 4519 ) 4520 if remove_saved_features: 4521 self._remove_stores(parameters) 4522 4523 def create_metadata_backup(self) -> None: 4524 """ 4525 Create a copy of the meta files 4526 """ 4527 4528 meta_copy_path = os.path.join(self.project_path, "meta", "backup") 4529 meta_path = os.path.join(self.project_path, "meta") 4530 if os.path.exists(meta_copy_path): 4531 shutil.rmtree(meta_copy_path) 4532 os.mkdir(meta_copy_path) 4533 for file in os.listdir(meta_path): 4534 if file == "backup": 4535 continue 4536 shutil.copy( 4537 os.path.join(meta_path, file), os.path.join(meta_copy_path, file) 4538 ) 4539 4540 def load_metadata_backup(self) -> None: 4541 """ 4542 Load from previously created meta data backup (in case of corruption) 4543 """ 4544 4545 meta_copy_path = os.path.join(self.project_path, "meta", "backup") 4546 meta_path = os.path.join(self.project_path, "meta") 4547 for file in os.listdir(meta_copy_path): 4548 shutil.copy( 4549 os.path.join(meta_copy_path, file), os.path.join(meta_path, file) 4550 ) 4551 4552 def get_behavior_dictionary(self, episode_name: str) -> Dict: 4553 """ 4554 Get the behavior dictionary for an episode 4555 4556 Parameters 4557 ---------- 4558 episode_name : str 4559 the name of the episode 4560 4561 Returns 4562 ------- 4563 behaviors_dictionary : dict 4564 a dictionary where keys are label indices and values are label names 4565 """ 4566 4567 run = self._episodes().get_runs(episode_name)[0] 4568 return self._episode(run).get_behaviors_dict() 4569 4570 def import_episodes( 4571 self, 4572 episodes_directory: str, 4573 name_map: Dict = None, 4574 repeat_policy: str = "error", 4575 ) -> None: 4576 """ 4577 Import episodes exported with `Project.export_episodes` 4578 4579 Parameters 4580 ---------- 4581 episodes_directory : str 4582 the path to the exported episodes directory 4583 name_map : dict, optional 4584 a name change dictionary for the episodes: keys are old names, values are new names repeat_policy : {'error', 'skip', 'force'}, default 'error' the policy for episode names that are already taken: raise an error, skip the episode or overwrite the existing one 4585 """ 4586 4587 if name_map is None: 4588 name_map = {} 4589 episodes = pd.read_pickle(os.path.join(episodes_directory, "episodes.pickle")) 4590 to_remove = [] 4591 import_string = "Imported episodes: " 4592 for episode_name in episodes.index: 4593 if episode_name in name_map: 4594 import_string += f"{episode_name} " 4595 episode_name = name_map[episode_name] 4596 import_string += f"({episode_name}), " 4597 else: 4598 import_string += f"{episode_name}, " 4599 try: 4600 self._check_episode_validity(episode_name, allow_doublecolon=True) 4601 except ValueError as e: 4602 if str(e).endswith("is already taken!"): 4603 if repeat_policy == "skip": 4604 to_remove.append(episode_name) 4605 elif repeat_policy == "force": 4606 self.remove_episode(episode_name) 4607 elif repeat_policy == "error": 4608 raise ValueError( 4609 f"The {episode_name} episode name is already taken; please use the name_map parameter to rename it" 4610 ) 4611 else: 4612 raise ValueError( 4613 f"The {repeat_policy} repeat policy is not recognized; please choose from ['skip', 'force', 'error']" 4614 ) 4615 episodes =
episodes.drop(index=to_remove) 4616 self._episodes().update( 4617 episodes, 4618 name_map=name_map, 4619 force=(repeat_policy == "force"), 4620 data_path=self.data_path, 4621 annotation_path=self.annotation_path, 4622 ) 4623 for episode_name in episodes.index: 4624 if episode_name in name_map: 4625 new_episode_name = name_map[episode_name] 4626 else: 4627 new_episode_name = episode_name 4628 model_dir = os.path.join( 4629 self.project_path, "results", "model", new_episode_name 4630 ) 4631 old_model_dir = os.path.join(episodes_directory, "model", episode_name) 4632 if os.path.exists(model_dir): 4633 shutil.rmtree(model_dir) 4634 os.mkdir(model_dir) 4635 for file in os.listdir(old_model_dir): 4636 shutil.copyfile( 4637 os.path.join(old_model_dir, file), os.path.join(model_dir, file) 4638 ) 4639 log_file = os.path.join( 4640 self.project_path, "results", "logs", f"{new_episode_name}.txt" 4641 ) 4642 old_log_file = os.path.join( 4643 episodes_directory, "logs", f"{episode_name}.txt" 4644 ) 4645 shutil.copyfile(old_log_file, log_file) 4646 print(import_string) 4647 print("\n") 4648 4649 def export_episodes( 4650 self, episode_names: List, output_directory: str, name: str = None 4651 ) -> None: 4652 """ 4653 Save selected episodes as a file that can be imported into another project with `Project.import_episodes` 4654 4655 Parameters 4656 ---------- 4657 episode_names : list 4658 a list of string episode names 4659 output_directory : str 4660 the path to the directory where the episodes will be saved 4661 name : str, optional 4662 the name of the episodes directory (by default `exported_episodes`) 4663 """ 4664 4665 if name is None: 4666 name = "exported_episodes" 4667 if os.path.exists( 4668 os.path.join(output_directory, name + ".zip") 4669 ) or os.path.exists(os.path.join(output_directory, name)): 4670 i = 1 4671 while os.path.exists( 4672 os.path.join(output_directory, name + f"_{i}.zip") 4673 ) or os.path.exists(os.path.join(output_directory, name + f"_{i}")): 4674 i += 1 4675 name = name + f"_{i}" 4676 dest_dir = os.path.join(output_directory, name) 4677 os.mkdir(dest_dir) 4678 os.mkdir(os.path.join(dest_dir, "model")) 4679 os.mkdir(os.path.join(dest_dir, "logs")) 4680 runs = [] 4681 for episode in episode_names: 4682 runs += self._episodes().get_runs(episode) 4683 for run in runs: 4684 shutil.copytree( 4685 os.path.join(self.project_path, "results", "model", run), 4686 os.path.join(dest_dir, "model", run), 4687 ) 4688 shutil.copyfile( 4689 os.path.join(self.project_path, "results", "logs", f"{run}.txt"), 4690 os.path.join(dest_dir, "logs", f"{run}.txt"), 4691 ) 4692 data = self._episodes().get_subset(runs) 4693 data.to_pickle(os.path.join(dest_dir, "episodes.pickle")) 4694 4695 def get_results_table( 4696 self, 4697 episode_names: List, 4698 metrics: List = None, 4699 include_std: bool = False, 4700 classes: List = None, 4701 ): 4702 """ 4703 Generate a `pandas` dataframe with a summary of episode results 4704 4705 Parameters 4706 ---------- 4707 episode_names : list 4708 a list of names of episodes to include 4709 metrics : list, optional 4710 a list of metric names to include 4711 include_std : bool, default False 4712 if `True`, for episodes with multiple runs the mean and standard deviation will be displayed; 4713 otherwise only the mean 4714 classes : list, optional 4715 a list of names of classes to include (by default all are included) 4716 4717 Returns 4718 ------- 4719 results : pd.DataFrame 4720 a table with the results 4721 """ 4722 4723 run_names = [] 4724 for episode in
episode_names: 4725 run_names += self._episodes().get_runs(episode) 4726 episodes = self.list_episodes(run_names, print_results=False) 4727 metric_columns = [x for x in episodes.columns if x[0] == "results"] 4728 results_df = pd.DataFrame() 4729 if metrics is not None: 4730 metric_columns = [ 4731 x for x in metric_columns if x[1].split("_")[0] in metrics 4732 ] 4733 for episode in episode_names: 4734 results = [] 4735 metric_set = set() 4736 for run in self._episodes().get_runs(episode): 4737 beh_dict = self.get_behavior_dictionary(run) 4738 res_dict = defaultdict(lambda: {}) 4739 for column in metric_columns: 4740 if np.isnan(episodes.loc[run, column]): 4741 continue 4742 split = column[1].split("_") 4743 if split[-1].isnumeric(): 4744 beh_ind = int(split[-1]) 4745 metric_name = "_".join(split[:-1]) 4746 beh = beh_dict[beh_ind] 4747 else: 4748 beh = "average" 4749 metric_name = column[1] 4750 res_dict[beh][metric_name] = episodes.loc[run, column] 4751 metric_set.add(metric_name) 4752 if "average" not in res_dict: 4753 res_dict["average"] = {} 4754 for metric in metric_set: 4755 if metric not in res_dict["average"]: 4756 arr = [ 4757 res_dict[beh][metric] 4758 for beh in res_dict 4759 if metric in res_dict[beh] 4760 ] 4761 res_dict["average"][metric] = np.mean(arr) 4762 results.append(res_dict) 4763 episode_results = {} 4764 for metric in metric_set: 4765 for beh in results[0].keys(): 4766 if classes is not None and beh not in classes: 4767 continue 4768 arr = [] 4769 for res_dict in results: 4770 if metric in res_dict[beh]: 4771 arr.append(res_dict[beh][metric]) 4772 if len(arr) > 0: 4773 if include_std: 4774 episode_results[ 4775 (beh, f"{episode} {metric} mean") 4776 ] = np.mean(arr) 4777 episode_results[(beh, f"{episode} {metric} std")] = np.std( 4778 arr 4779 ) 4780 else: 4781 episode_results[(beh, f"{episode} {metric}")] = np.mean(arr) 4782 for key, value in episode_results.items(): 4783 results_df.loc[key[0], key[1]] = value 4784 print(f"RESULTS:") 4785 print(results_df) 4786 print("\n") 4787 return results_df 4788 4789 def episode_exists(self, episode_name: str) -> bool: 4790 """ 4791 Check if an episode already exists 4792 4793 Parameters 4794 ---------- 4795 episode_name : str 4796 the episode name 4797 4798 Returns 4799 ------- 4800 exists : bool 4801 `True` if the episode exists 4802 """ 4803 4804 return self._episodes().check_name_validity(episode_name) 4805 4806 def search_exists(self, search_name: str) -> bool: 4807 """ 4808 Check if a search already exists 4809 4810 Parameters 4811 ---------- 4812 search_name : str 4813 the search name 4814 4815 Returns 4816 ------- 4817 exists : bool 4818 `True` if the search exists 4819 """ 4820 4821 return self._searches().check_name_validity(search_name) 4822 4823 def prediction_exists(self, prediction_name: str) -> bool: 4824 """ 4825 Check if a prediction already exists 4826 4827 Parameters 4828 ---------- 4829 prediction_name : str 4830 the prediction name 4831 4832 Returns 4833 ------- 4834 exists : bool 4835 `True` if the prediction exists 4836 """ 4837 4838 return self._predictions().check_name_validity(prediction_name) 4839 4840 @staticmethod 4841 def project_name_available(projects_path: str, project_name: str): 4842 if projects_path is None: 4843 projects_path = os.path.join(str(Path.home()), "DLC2Action") 4844 return not os.path.exists(os.path.join(projects_path, project_name)) 4845 4846 def _update_episode_metrics(self, episode_name: str, metrics: Dict): 4847 """ 4848 Update meta data with evaluation results 4849 """ 4850 
4851 self._episodes().update_episode_metrics(episode_name, metrics) 4852 4853 def rename_episode(self, episode_name: str, new_episode_name: str): """ Rename an episode and move the corresponding model and log files """ 4854 shutil.move( 4855 os.path.join(self.project_path, "results", "model", episode_name), 4856 os.path.join(self.project_path, "results", "model", new_episode_name), 4857 ) 4858 shutil.move( 4859 os.path.join(self.project_path, "results", "logs", f"{episode_name}.txt"), 4860 os.path.join( 4861 self.project_path, "results", "logs", f"{new_episode_name}.txt" 4862 ), 4863 ) 4864 self._episodes().rename_episode(episode_name, new_episode_name)
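The evaluation helpers above are easiest to see in context. Below is a minimal usage sketch; the project and episode names are hypothetical placeholders, not values from this module:

from dlc2action.project.project import Project

project = Project("my_project")  # an existing project; "my_project" is a placeholder
# plot frame-wise and bout-wise class distributions, hiding classes with
# fewer than 100 frames or 5 bouts
project.plot_class_distribution(frame_cutoff=100, bout_cutoff=5, print_full=True)
# summarize the f1 results of all runs of an episode in a dataframe
table = project.get_results_table(["test_episode"], metrics=["f1"], include_std=True)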
618 def run_episode( 619 self, 620 episode_name: str, 621 load_episode: str = None, 622 parameters_update: Dict = None, 623 task: TaskDispatcher = None, 624 load_epoch: int = None, 625 load_search: str = None, 626 load_parameters: list = None, 627 round_to_binary: list = None, 628 load_strict: bool = True, 629 n_seeds: int = 1, 630 force: bool = False, 631 suppress_name_check: bool = False, 632 remove_saved_features: bool = False, 633 mask_name: str = None, 634 autostop_metric: str = None, 635 autostop_interval: int = 50, 636 autostop_threshold: float = 0.001, 637 loading_bar: bool = False, 638 trial: Tuple = None, 639 ) -> TaskDispatcher: 640 """ 641 Run an episode 642 643 The task parameters are read from the config files and then updated with the 644 parameters_update dictionary. The model can be either initialized from scratch or loaded from one of the 645 previous experiments. All parameters and results are saved in the meta files and can be accessed with the 646 list_episodes() function. The train/test/validation split is saved and loaded from a file whenever the 647 same split parameters are used. The pre-computed datasets are also saved and loaded whenever the same 648 data parameters are used. 649 650 You can use the autostop parameters to finish training when the results are not improving. It will be 651 stopped if the average value of `autostop_metric` over the last `autostop_interval` epochs is smaller than 652 the average over the previous `autostop_interval` epochs + `autostop_threshold`. For example, if the 653 current epoch is 120 and `autostop_interval` is 50, the averages over epochs 70-120 and 20-70 will be compared. 654 655 Parameters 656 ---------- 657 episode_name : str 658 the episode name 659 load_episode : str, optional 660 the (previously run) episode name to load the model from; if the episode has multiple runs, 661 the new episode will have the same number of runs, each starting with one of the pre-trained models 662 parameters_update : dict, optional 663 the dictionary used to update the parameters from the config files 664 task : TaskDispatcher, optional 665 a pre-existing `TaskDispatcher` object (if provided, the method will update it instead of creating 666 a new instance) 667 load_epoch : int, optional 668 the epoch to load (if load_episode is not None); if not provided, the last epoch is used 669 load_search : str, optional 670 the hyperparameter search result to load 671 load_parameters : list, optional 672 a list of string names of the parameters to load from load_search (if not provided, all parameters 673 are loaded) 674 round_to_binary : list, optional 675 a list of string names of the loaded parameters that should be rounded to the nearest power of two 676 load_strict : bool, default True 677 if `False`, matching weights will be loaded from `load_episode` and differences in parameter name lists and 678 weight shapes will be ignored; otherwise mismatches will prompt a `RuntimeError` 679 n_seeds : int, default 1 680 the number of runs to perform with different random seeds; if `n_seeds > 1`, the episodes will be named 681 `episode_name::seed_index`, e.g. `test_episode::0` and `test_episode::1` 682 force : bool, default False 683 if `True` and an episode with name `episode_name` already exists, it will be overwritten (use with caution!)
684 suppress_name_check : bool, default False 685 if `True`, episode names with a double colon are allowed (please don't use this option unless you understand 686 why they are usually forbidden) 687 remove_saved_features : bool, default False 688 if `True`, the dataset will be deleted after training 689 mask_name : str, optional 690 the name of the mask to apply 691 autostop_interval : int, default 50 692 the number of epochs to average the autostop metric over 693 autostop_threshold : float, default 0.001 694 the autostop difference threshold 695 autostop_metric : str, optional 696 the autostop metric (can be any one of the tracked metrics or `'loss'`) loading_bar : bool, default False if `True`, a progress bar is displayed during training trial : tuple, optional a `(trial, metric_name)` pair used internally during hyperparameter searches 697 """ 698 699 if type(n_seeds) is not int or n_seeds < 1: 700 raise ValueError( 701 f"The n_seeds parameter has to be an integer larger than 0; got {n_seeds}" 702 ) 703 if n_seeds > 1 and mask_name is not None: 704 raise ValueError("Cannot apply a mask with n_seeds > 1") 705 self._check_episode_validity( 706 episode_name, allow_doublecolon=suppress_name_check, force=force 707 ) 708 load_runs = self._episodes().get_runs(load_episode) 709 if len(load_runs) > 1: 710 task = self.run_episodes( 711 episode_names=[ 712 f'{episode_name}::{run.split("::")[-1]}' for run in load_runs 713 ], 714 load_episodes=load_runs, 715 parameters_updates=[parameters_update for _ in load_runs], 716 load_epochs=[load_epoch for _ in load_runs], 717 load_searches=[load_search for _ in load_runs], 718 load_parameters=[load_parameters for _ in load_runs], 719 round_to_binary=[round_to_binary for _ in load_runs], 720 load_strict=[load_strict for _ in load_runs], 721 suppress_name_check=True, 722 force=force, 723 remove_saved_features=False, 724 ) 725 if remove_saved_features: 726 self._remove_stores( 727 { 728 "general": task.general_parameters, 729 "data": task.data_parameters, 730 "features": task.feature_parameters, 731 } 732 ) 733 if n_seeds > 1: 734 warnings.warn( 735 f"The n_seeds parameter is disregarded since load_episode={load_episode} has multiple runs" 736 ) 737 elif n_seeds > 1: 738 self.run_episodes( 739 episode_names=[f"{episode_name}::{i}" for i in range(n_seeds)], 740 load_episodes=[load_episode for _ in range(n_seeds)], 741 parameters_updates=[parameters_update for _ in range(n_seeds)], 742 load_epochs=[load_epoch for _ in range(n_seeds)], 743 load_searches=[load_search for _ in range(n_seeds)], 744 load_parameters=[load_parameters for _ in range(n_seeds)], 745 round_to_binary=[round_to_binary for _ in range(n_seeds)], 746 load_strict=[load_strict for _ in range(n_seeds)], 747 suppress_name_check=True, 748 force=force, 749 remove_saved_features=remove_saved_features, 750 ) 751 else: 752 print(f"TRAINING {episode_name}") 753 try: 754 task, parameters = self._make_task_training( 755 episode_name, 756 load_episode, 757 parameters_update, 758 load_epoch, 759 load_search, 760 load_parameters, 761 round_to_binary, 762 continuing=False, 763 task=task, 764 mask_name=mask_name, 765 load_strict=load_strict, 766 ) 767 self._save_episode( 768 episode_name, 769 parameters, 770 task.behaviors_dict(), 771 norm_stats=task.get_normalization_stats(), 772 ) 773 time_start = time.time() 774 if trial is not None: 775 trial, metric = trial 776 else: 777 trial, metric = None, None 778 logs = task.train( 779 autostop_metric=autostop_metric, 780 autostop_interval=autostop_interval, 781 autostop_threshold=autostop_threshold, 782 loading_bar=loading_bar, 783 trial=trial, 784 optimized_metric=metric, 785 ) 786 time_end = time.time() 787 time_total = time_end -
time_start 788 hours = int(time_total // 3600) 789 time_total -= hours * 3600 790 minutes = int(time_total // 60) 791 time_total -= minutes * 60 792 seconds = int(time_total) 793 training_time = f"{hours}:{minutes:02}:{seconds:02}" 794 self._update_episode_results(episode_name, logs, training_time) 795 if remove_saved_features: 796 self._remove_stores(parameters) 797 print("\n") 798 return task 799 800 except Exception as e: 801 if isinstance(e, optuna.exceptions.TrialPruned): 802 raise e 803 else: 804 # if str(e) != f"The {episode_name} episode name is already in use!": 805 # self.remove_episode(episode_name) 806 raise RuntimeError(f"Episode {episode_name} could not run") from e
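For orientation, here is a hedged sketch of a typical `run_episode` call; the episode name and parameter values are hypothetical, not recommendations:

# train three seeds (test_episode::0 .. test_episode::2) with early stopping
project.run_episode(
    "test_episode",
    parameters_update={"training": {"num_epochs": 100}},
    n_seeds=3,
    autostop_metric="f1",  # compare averages over autostop_interval epochs
    autostop_interval=50,
    autostop_threshold=0.001,
)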
808 def run_episodes( 809 self, 810 episode_names: List, 811 load_episodes: List = None, 812 parameters_updates: List = None, 813 load_epochs: List = None, 814 load_searches: List = None, 815 load_parameters: List = None, 816 round_to_binary: List = None, 817 load_strict: List = None, 818 force: bool = False, 819 suppress_name_check: bool = False, 820 remove_saved_features: bool = False, 821 ) -> TaskDispatcher: 822 """ 823 Run multiple episodes in sequence (and re-use previously loaded information) 824 825 For each episode, the task parameters are read from the config files and then updated with the 826 parameter_update dictionary. The model can be either initialized from scratch or loaded from one of the 827 previous experiments. All parameters and results are saved in the meta files and can be accessed with the 828 list_episodes() function. The train/test/validation split is saved and loaded from a file whenever the 829 same split parameters are used. The pre-computed datasets are also saved and loaded whenever the same 830 data parameters are used. 831 832 Parameters 833 ---------- 834 episode_names : list 835 a list of strings of episode names 836 load_episodes : list, optional 837 a list of strings of (previously run) episode names to load the model from; if the episode has multiple runs, 838 the new episode will have the same number of runs, each starting with one of the pre-trained models 839 parameters_updates : list, optional 840 a list of dictionaries used to update the parameters from the config 841 load_epochs : list, optional 842 a list of integers used to specify the epoch to load (if load_episodes is not None) 843 load_searches : list, optional 844 a list of strings of hyperparameter search results to load 845 load_parameters : list, optional 846 a list of lists of string names of the parameters to load from the searches 847 round_to_binary : list, optional 848 a list of string names of the loaded parameters that should be rounded to the nearest power of two 849 load_strict : list, optional 850 a list of boolean values specifying weight loading policy: if `False`, matching weights will be loaded from 851 the corresponding episode and differences in parameter name lists and 852 weight shapes will be ignored; otherwise mismatches will prompt a `RuntimeError` (by default `True` for 853 every episode) 854 force : bool, default False 855 if `True` and an episode name is already taken, it will be overwritten (use with caution!) 
856 suppress_name_check : bool, default False 857 if `True`, episode names with a double colon are allowed (please don't use this option unless you understand 858 why they are usually forbidden) 859 remove_saved_features : bool, default False 860 if `True`, the dataset will be deleted after training 861 """ 862 863 task = None 864 if load_searches is None: 865 load_searches = [None for _ in episode_names] 866 if load_episodes is None: 867 load_episodes = [None for _ in episode_names] 868 if parameters_updates is None: 869 parameters_updates = [None for _ in episode_names] 870 if load_parameters is None: 871 load_parameters = [None for _ in episode_names] 872 if load_epochs is None: 873 load_epochs = [None for _ in episode_names] 874 if load_strict is None: 875 load_strict = [True for _ in episode_names] 876 for ( 877 parameters_update, 878 episode_name, 879 load_episode, 880 load_epoch, 881 load_search, 882 load_parameters_list, 883 load_strict_value, 884 ) in zip( 885 parameters_updates, 886 episode_names, 887 load_episodes, 888 load_epochs, 889 load_searches, 890 load_parameters, 891 load_strict, 892 ): 893 task = self.run_episode( 894 episode_name, 895 load_episode, 896 parameters_update, 897 task, 898 load_epoch, 899 load_search, 900 load_parameters_list, 901 round_to_binary, 902 load_strict_value, 903 suppress_name_check=suppress_name_check, 904 force=force, 905 remove_saved_features=remove_saved_features, 906 ) 907 return task
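As an illustration, a sketch of two episodes trained back to back so that the dataset only has to be loaded once; the episode names and learning-rate values are hypothetical:

project.run_episodes(
    episode_names=["episode_lr3", "episode_lr4"],
    parameters_updates=[
        {"training": {"lr": 1e-3}},  # hypothetical training parameter updates
        {"training": {"lr": 1e-4}},
    ],
)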
909 def continue_episode( 910 self, 911 episode_name: str, 912 num_epochs: int = None, 913 task: TaskDispatcher = None, 914 n_seeds: int = 1, 915 remove_saved_features: bool = False, 916 device: str = "cuda", 917 num_cpus: int = None, 918 ) -> TaskDispatcher: 919 """ 920 Load an older episode and continue running from the latest checkpoint 921 922 All parameters as well as the model and optimizer state dictionaries are loaded from the episode. 923 924 Parameters 925 ---------- 926 episode_name : str 927 the name of the episode to continue 928 num_epochs : int, optional 929 the new number of epochs 930 task : TaskDispatcher, optional 931 a pre-existing task; if provided, the method will update the task instead of creating a new one 932 (this might save time, mainly on dataset loading) 933 num_cpus : int, optional 934 the number of CPU cores to use (if not provided, the value from the 935 config files is kept) 936 n_seeds : int, default 1 937 the number of runs to perform; if `n_seeds > 1`, the episodes will be named `episode_name::run_index`, e.g. 938 `test_episode::0` and `test_episode::1` 939 remove_saved_features : bool, default False 940 if `True`, pre-computed features will be deleted after the run 941 device : str, default "cuda" 942 the torch device to use 943 """ 944 945 runs = self._episodes().get_runs(episode_name) 946 for run in runs: 947 print(f"TRAINING {run}") 948 if num_epochs is None and not self._episode(run).unfinished(): 949 continue 950 parameters_update = { 951 "training": { 952 "num_epochs": num_epochs, 953 "device": device, 954 }, 955 "general": {"num_cpus": num_cpus}, 956 } 957 task, parameters = self._make_task_training( 958 run, 959 load_episode=run, 960 parameters_update=parameters_update, 961 continuing=True, 962 task=task, 963 ) 964 time_start = time.time() 965 logs = task.train() 966 time_end = time.time() 967 old_time = self._training_time(run) 968 if not np.isnan(old_time): 969 time_end += old_time 970 time_total = time_end - time_start 971 hours = int(time_total // 3600) 972 time_total -= hours * 3600 973 minutes = int(time_total // 60) 974 time_total -= minutes * 60 975 seconds = int(time_total) 976 training_time = f"{hours}:{minutes:02}:{seconds:02}" 977 else: 978 training_time = np.nan 979 self._save_episode( 980 run, 981 parameters, 982 task.behaviors_dict(), 983 suppress_validation=True, 984 training_time=training_time, 985 norm_stats=task.get_normalization_stats(), 986 ) 987 self._update_episode_results(run, logs) 988 print("\n") 989 if len(runs) < n_seeds: 990 for i in range(len(runs), n_seeds): 991 self.run_episode( 992 f"{episode_name}::{i}", 993 parameters_update=self._episodes().load_parameters(runs[0]), 994 task=task, 995 suppress_name_check=True, 996 ) 997 if remove_saved_features: 998 self._remove_stores(parameters) 999 return task
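A minimal sketch (the episode name is hypothetical): extend a previously finished episode to 200 epochs and add a second seed if only one run exists so far:

project.continue_episode("test_episode", num_epochs=200, n_seeds=2)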
1001 def run_default_hyperparameter_search( 1002 self, 1003 search_name: str, 1004 model_name: str = None, 1005 metric: str = "f1", 1006 best_n: int = 3, 1007 direction: str = "maximize", 1008 load_episode: str = None, 1009 load_epoch: int = None, 1010 load_strict: bool = True, 1011 prune: bool = True, 1012 force: bool = False, 1013 remove_saved_features: bool = False, 1014 overlap: float = 0, 1015 num_epochs: int = 50, 1016 test_frac: float = 0, 1017 n_trials=150, 1018 device: str = None, 1019 ): 1020 """ 1021 Run an optuna hyperparameter search with default parameters for a model 1022 1023 For the vast majority of cases, optimizing the default parameters should be enough. 1024 Check out `dlc2action.options.model_hyperparameters` for the lists of parameters. 1025 There are also options to set overlap, test fraction and number of epochs parameters for the search without 1026 modifying the project config files. However, if you want something more complex, look into 1027 `Project.run_hyperparameter_search`. 1028 1029 The task parameters are read from the config files and updated with the parameters_update dictionary. 1030 The model can be either initialized from scratch or loaded from a previously run episode. 1031 For each trial, the objective metric is averaged over a few best epochs. 1032 1033 Parameters 1034 ---------- 1035 search_name : str 1036 the name of the search to store it in the meta files and load in run_episode 1037 model_name : str, optional 1038 the name of the model (by default loaded from the project settings, see `project.help('models')` for options) 1039 metric : str, default f1 1040 the metric to maximize/minimize (see direction); if the metric has an `"average"` parameter and it is set to 1041 `"none"` in the config files, it will be reset to `"macro"` for the search; see `project.help('metrics')` for options 1042 n_trials : int, default 150 1043 the number of optimization trials to run 1044 best_n : int, default 3 1045 the number of epochs to average the metric over; if 0, the last value is taken 1046 overlap : float, default 0 1047 the overlap between training samples to set for the search num_epochs : int, default 50 the number of epochs to train for in each trial test_frac : float, default 0 the fraction of the dataset to hold out as a test subset during the search 1048 direction : {'maximize', 'minimize'} 1049 optimization direction 1050 load_episode : str, optional 1051 the name of the episode to load the model from 1052 load_epoch : int, optional 1053 the epoch to load the model from (if not provided, the last checkpoint is used) load_strict : bool, default True if `False`, matching weights will be loaded from `load_episode` and mismatches in parameter names and shapes will be ignored 1054 prune : bool, default True 1055 if `True`, experiments where the optimized metric is improving too slowly will be terminated 1056 (with optuna HyperBand pruner) 1057 force : bool, default False 1058 if `True`, existing searches with the same name will be overwritten 1059 remove_saved_features : bool, default False 1060 if `True`, pre-computed features will be deleted after each run (if the data parameters change) 1061 device : str, optional 1062 cuda:{i} or cpu, if not given it is read from the default parameters 1063 1064 Returns 1065 ------- 1066 dict 1067 a dictionary of best parameters 1068 """ 1069 1070 if model_name is None: 1071 model_name = self._read_parameters()["general"]["model_name"] 1072 if model_name not in options.model_hyperparameters: 1073 raise ValueError( 1074 f"There is no default search space for {model_name}!
Please choose from {options.model_hyperparameters.keys()} or try project.run_hyperparameter_search()" 1075 ) 1076 pars = { 1077 "general": { 1078 "overlap": overlap, 1079 "model_name": model_name, 1080 "metric_functions": {metric}, 1081 }, 1082 "training": {"num_epochs": num_epochs}, 1083 } 1084 if test_frac is not None: 1085 pars["training"]["test_frac"] = test_frac 1086 if not metric.split("_")[-1].isnumeric(): 1087 project_pars = self._read_parameters() 1088 if project_pars["metrics"][metric].get("average") == "none": 1089 pars["metrics"] = {metric: {"average": "macro"}} 1090 if device is not None: 1091 pars["training"]["device"] = device 1092 return self.run_hyperparameter_search( 1093 search_name=search_name, 1094 search_space=options.model_hyperparameters[model_name], 1095 metric=metric, 1096 n_trials=n_trials, 1097 best_n=best_n, 1098 parameters_update=pars, 1099 direction=direction, 1100 load_episode=load_episode, 1101 load_epoch=load_epoch, 1102 load_strict=load_strict, 1103 prune=prune, 1104 force=force, 1105 remove_saved_features=remove_saved_features, 1106 )
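A minimal usage sketch (the project and search names here are hypothetical, and it assumes an existing project with all config blanks filled in):

from dlc2action.project.project import Project

project = Project("my_project")  # hypothetical, previously initialized project
best_params = project.run_default_hyperparameter_search(
    "tcn_search",  # hypothetical search name
    metric="f1",
    n_trials=50,
    num_epochs=30,
)
# load the best parameters into the project config for future episodes
project.update_parameters(load_search="tcn_search")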
def run_hyperparameter_search(
    self,
    search_name: str,
    search_space: Dict,
    metric: str = "f1",
    n_trials: int = 20,
    best_n: int = 1,
    parameters_update: Dict = None,
    direction: str = "maximize",
    load_episode: str = None,
    load_epoch: int = None,
    load_strict: bool = True,
    prune: bool = False,
    force: bool = False,
    remove_saved_features: bool = False,
) -> Dict:
    """
    Run an optuna hyperparameter search

    For a simpler function that fits most use cases, check out `Project.run_default_hyperparameter_search()`.

    To use a default search space with this method, import `dlc2action.options.model_hyperparameters`. It is
    a dictionary where keys are model names and values are default search spaces.

    The task parameters are read from the config files and updated with the parameters_update dictionary.
    The model can be either initialized from scratch or loaded from a previously run episode.
    For each trial, the objective metric is averaged over a few best epochs.

    Parameters
    ----------
    search_name : str
        the name of the search to store it in the meta files and load in run_episode
    search_space : dict
        a dictionary representing the search space; of this general structure:
        {'group/param_name': ('float/int/float_log/int_log', start, end),
        'group/param_name': ('categorical', [choices])}, e.g.
        {'data/overlap': ('int', 5, 100), 'training/lr': ('float_log', 1e-4, 1e-2),
        'data/feature_extraction': ('categorical', ['kinematic', 'bones'])}
    metric : str, default "f1"
        the metric to maximize/minimize (see direction)
    n_trials : int, default 20
        the number of optimization trials to run
    best_n : int, default 1
        the number of epochs to average the metric over; if 0, the last value is taken
    parameters_update : dict, optional
        the parameters update dictionary
    direction : {'maximize', 'minimize'}
        optimization direction
    load_episode : str, optional
        the name of the episode to load the model from
    load_epoch : int, optional
        the epoch to load the model from (if not provided, the last checkpoint is used)
    load_strict : bool, default True
        if True, the model parameters are loaded strictly (the architectures have to match exactly)
    prune : bool, default False
        if True, experiments where the optimized metric is improving too slowly will be terminated
        (with the optuna HyperBand pruner)
    force : bool, default False
        if True, existing searches with the same name will be overwritten
    remove_saved_features : bool, default False
        if True, pre-computed features will be deleted after each run (if the data parameters change)

    Returns
    -------
    dict
        a dictionary of best parameters
    """

    self._check_search_validity(search_name, force=force)
    print(f"SEARCH {search_name}")
    self.remove_episode(f"_{search_name}")
    if parameters_update is None:
        parameters_update = {}
    parameters_update = self._update(
        parameters_update, {"general": {"metric_functions": {metric}}}
    )
    parameters = self._make_parameters(
        f"_{search_name}",
        load_episode,
        parameters_update,
        parameters_update_second={"training": {"model_save_path": None}},
        load_epoch=load_epoch,
        load_strict=load_strict,
    )
    task = None

    if prune:
        pruner = optuna.pruners.HyperbandPruner()
    else:
        pruner = optuna.pruners.NopPruner()
    study = optuna.create_study(direction=direction, pruner=pruner)
    runner = _Runner(
        search_space=search_space,
        load_episode=load_episode,
        load_epoch=load_epoch,
        metric=metric,
        average=best_n,
        task=task,
        remove_saved_features=remove_saved_features,
        project=self,
        search_name=search_name,
    )
    study.optimize(lambda trial: runner.run(trial, parameters), n_trials=n_trials)
    search_path = self._search_path(search_name)
    os.mkdir(search_path)
    fig = optuna.visualization.plot_contour(study)
    plotly.offline.plot(
        fig, filename=os.path.join(search_path, f"{search_name}_contour.html")
    )
    fig = optuna.visualization.plot_param_importances(study)
    plotly.offline.plot(
        fig, filename=os.path.join(search_path, f"{search_name}_importances.html")
    )
    best_params = study.best_params
    best_value = study.best_value
    self._save_search(
        search_name,
        parameters,
        n_trials,
        best_params,
        best_value,
        metric,
        search_space,
    )
    self.remove_episode(f"_{search_name}")
    runner.clean()
    print(f"best parameters: {best_params}")
    print("\n")
    return best_params
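A sketch with a custom search space, following the structure described above (the search name and parameter ranges are illustrative):

search_space = {
    "data/overlap": ("int", 5, 100),
    "training/lr": ("float_log", 1e-4, 1e-2),
    "data/feature_extraction": ("categorical", ["kinematic", "bones"]),
}
best_params = project.run_hyperparameter_search(
    "custom_search",  # hypothetical search name
    search_space=search_space,
    metric="f1",
    n_trials=30,
    prune=True,
)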
def run_prediction(
    self,
    prediction_name: str,
    episode_names: List,
    load_epochs: List = None,
    parameters_update: Dict = None,
    augment_n: int = 10,
    data_path: str = None,
    mode: str = "all",
    file_paths: Set = None,
    remove_saved_features: bool = False,
    submission: bool = False,
    frame_number_map_file: str = None,
    force: bool = False,
    embedding: bool = False,
) -> None:
    """
    Load models from previously run episodes to generate a prediction

    The probabilities predicted by the models are averaged.
    Unless `submission` is `True`, the prediction results are saved as pickled dictionaries in the
    project_name/results/predictions/{prediction_name} folder, one
    {video_id}_{prediction_name}_prediction.pickle file per video. Each file contains a dictionary where
    the keys are the clip ids (like individual names) and the values are the prediction arrays, plus
    'min_frames', 'max_frames' and 'behaviors' entries.

    Parameters
    ----------
    prediction_name : str
        the name of the prediction
    episode_names : list
        a list of string episode names to load the models from
    load_epochs : list, optional
        a list of integer epoch indices to load the model from; if None, the last ones are used
    parameters_update : dict, optional
        a dictionary of parameter updates
    augment_n : int, default 10
        the number of augmentations to average over
    data_path : str, optional
        the data path to run the prediction for
    mode : {'all', 'test', 'val', 'train'}
        the subset of the data to make the prediction for (forced to 'all' if data_path is not None)
    file_paths : set, optional
        a set of string file paths (data with all prefixes + feature files, in any order) to run the
        prediction for
    remove_saved_features : bool, default False
        if True, pre-computed features will be deleted
    submission : bool, default False
        if True, a MABe-22 style submission file is generated
    frame_number_map_file : str, optional
        path to the frame number map file
    force : bool, default False
        if True, an existing prediction with this name will be overwritten
    """

    self._check_prediction_validity(prediction_name, force=force)
    print(f"PREDICTION {prediction_name}")
    if submission:
        task = ...  # TODO: add submission option to _make_prediction
        predicted = task.generate_submission(
            frame_number_map_file=frame_number_map_file,
            dataset=mode,
            augment_n=augment_n,
        )
        folder = os.path.join(
            self.project_path,
            "results",
            "predictions",
            f"{prediction_name}",
        )
        filename = os.path.join(folder, f"{prediction_name}.npy")
        np.save(filename, predicted, allow_pickle=True)
    else:
        try:
            (
                task,
                parameters,
                mode,
                prediction,
                inference_time,
            ) = self._make_prediction(
                prediction_name,
                episode_names,
                load_epochs,
                parameters_update,
                data_path,
                file_paths,
                mode,
                augment_n,
                evaluate=False,
                embedding=embedding,
            )
            predicted = task.dataset(mode).generate_full_length_prediction(
                prediction
            )
        except ValueError:
            # fall back to aggregating per-episode predictions
            (
                task,
                parameters,
                mode,
                predicted,
                inference_time,
            ) = self._aggregate_predictions(
                prediction_name,
                episode_names,
                load_epochs,
                parameters_update,
                data_path,
                file_paths,
                mode,
                augment_n,
                evaluate=False,
                embedding=embedding,
            )
        folder = self.prediction_path(prediction_name)
        os.mkdir(folder)
        for video_id, prediction in predicted.items():
            with open(
                os.path.join(
                    folder, video_id + f"_{prediction_name}_prediction.pickle"
                ),
                "wb",
            ) as f:
                prediction["min_frames"], prediction["max_frames"] = task.dataset(
                    mode
                ).get_min_max_frames(video_id)
                behavior_indices = sorted(
                    [key for key in task.behaviors_dict() if key != -100]
                )
                prediction["behaviors"] = [
                    task.behaviors_dict()[key] for key in behavior_indices
                ]
                pickle.dump(prediction, f)
    if remove_saved_features:
        self._remove_stores(parameters)
    self._save_prediction(
        prediction_name,
        parameters,
        task.behaviors_dict(),
        embedding,
        inference_time,
    )
    print("\n")
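A minimal sketch, assuming "test_episode" is an already trained episode (all names are hypothetical):

import os
import pickle

project.run_prediction(
    "test_prediction",
    episode_names=["test_episode"],
    augment_n=10,
)
# each video gets its own pickled dictionary in the prediction folder
folder = project.prediction_path("test_prediction")
file_name = os.listdir(folder)[0]
with open(os.path.join(folder, file_name), "rb") as f:
    prediction = pickle.load(f)
print(prediction["behaviors"])  # the behavior order of the prediction arrays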
def evaluate_prediction(
    self,
    prediction_name: str,
    parameters_update: Dict = None,
    data_path: str = None,
    file_paths: Set = None,
    mode: str = None,
    remove_saved_features: bool = False,
) -> Tuple[float, dict]:
    """
    Compute metrics for a previously saved prediction

    The data parameters are loaded from the prediction record and updated with `parameters_update`.
    """

    with open(
        os.path.join(
            self.project_path, "results", "predictions", f"{prediction_name}.pickle"
        ),
        "rb",
    ) as f:
        prediction = pickle.load(f)
    if parameters_update is None:
        parameters_update = {}
    parameters_update = self._update(
        self._predictions().load_parameters(prediction_name), parameters_update
    )
    parameters_update.pop("model")
    task, parameters, mode = self._make_task_prediction(
        "_",
        load_episode=None,
        parameters_update=parameters_update,
        data_path=data_path,
        file_paths=file_paths,
        mode=mode,
    )
    results = task.evaluate_prediction(prediction, data=mode)
    if remove_saved_features:
        self._remove_stores(parameters)
    print("\n")
    return results
def evaluate(
    self,
    episode_names: List,
    load_epochs: List = None,
    augment_n: int = 0,
    data_path: str = None,
    file_paths: Set = None,
    mode: str = None,
    parameters_update: Dict = None,
    multiple_episode_policy: str = "average",
    remove_saved_features: bool = False,
    skip_updating_meta: bool = True,
) -> Dict:
    """
    Load one or several models from previously run episodes to make an evaluation

    By default it will run on the test (or validation, if there is no test) subset of the project dataset.

    Parameters
    ----------
    episode_names : list
        a list of string episode names to load the models from
    load_epochs : list, optional
        a list of integer epoch indices to load the model from; if None, the last ones are used
    augment_n : int, default 0
        the number of augmentations to average over
    data_path : str, optional
        the data path to run the prediction for
    file_paths : set, optional
        a set of files to run the prediction for
    mode : {'test', 'val', 'train', 'all'}
        the subset of the data to make the prediction for (forced to 'all' if data_path is not None;
        by default 'test' if the test subset is not empty and 'val' otherwise)
    parameters_update : dict, optional
        a dictionary with parameter updates (cannot change model parameters)
    multiple_episode_policy : {'average', 'statistics'}
        if 'average', the predictions of all episodes are averaged before the metrics are computed;
        if 'statistics', the metrics are computed per episode and reported as mean and standard deviation
    remove_saved_features : bool, default False
        if True, the dataset will be deleted after the evaluation
    skip_updating_meta : bool, default True
        if False, validation results are written back to the episode meta records

    Returns
    -------
    metric : dict
        a dictionary of average values of metric functions
    """

    names = []
    for episode_name in episode_names:
        names += self._episodes().get_runs(episode_name)
    if len(set(episode_names)) == 1:
        print(f"EVALUATION {episode_names[0]}")
    else:
        print(f"EVALUATION {episode_names}")
    evaluate = len(names) > 1
    if multiple_episode_policy == "average":
        try:
            (
                task,
                parameters,
                mode,
                prediction,
                inference_time,
            ) = self._make_prediction(
                "_",
                episode_names,
                load_epochs,
                parameters_update,
                mode=mode,
                data_path=data_path,
                file_paths=file_paths,
                augment_n=augment_n,
                evaluate=evaluate,
            )
        except Exception:  # fall back to aggregating per-episode predictions
            (
                task,
                parameters,
                mode,
                prediction,
                inference_time,
            ) = self._aggregate_predictions(
                "_",
                episode_names,
                load_epochs,
                parameters_update,
                mode=mode,
                data_path=data_path,
                file_paths=file_paths,
                augment_n=augment_n,
                evaluate=evaluate,
            )
        print("AGGREGATED:")
        _, results = task.evaluate_prediction(prediction, data=mode)
        if len(names) == 1 and mode == "val" and not skip_updating_meta:
            self._update_episode_metrics(names[0], results)
    elif multiple_episode_policy == "statistics":
        values = defaultdict(lambda: [])
        task = None
        for name in names:
            (
                task,
                parameters,
                mode,
                prediction,
                inference_time,
            ) = self._make_prediction(
                "_",
                [name],
                load_epochs,
                parameters_update,
                mode=mode,
                data_path=data_path,
                file_paths=file_paths,
                augment_n=augment_n,
                evaluate=evaluate,
                task=task,
            )
            _, metrics = task.evaluate_prediction(prediction, data=mode)
            # the metric loop variable is renamed so it does not shadow the run name
            for metric_name, value in metrics.items():
                values[metric_name].append(value)
            if mode == "val" and not skip_updating_meta:
                self._update_episode_metrics(name, metrics)
        results = defaultdict(lambda: {})
        mean_string = ""
        std_string = ""
        for key, value_list in values.items():
            results[key]["mean"] = np.mean(value_list)
            results[key]["std"] = np.std(value_list)
            mean_string += f"{key} {np.mean(value_list):.3f}, "
            std_string += f"{key} {np.std(value_list):.3f}, "
        print("MEAN:")
        print(mean_string)
        print("STD:")
        print(std_string)
    else:
        raise ValueError(
            f"The {multiple_episode_policy} multiple episode policy is not recognized; please choose "
            f"from ['average', 'statistics']"
        )
    if len(names) > 0 and remove_saved_features:
        self._remove_stores(parameters)
    print(f"Inference time: {inference_time}")
    print("\n")
    return results
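Two usage sketches with hypothetical episode names, one per multiple-episode policy:

# average the predictions of several episodes before computing the metrics
results = project.evaluate(
    ["test_episode_1", "test_episode_2"],
    multiple_episode_policy="average",
)
# or compute the metrics per run and report mean and standard deviation
stats = project.evaluate(
    ["test_episode_1"],
    multiple_episode_policy="statistics",
)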
def list_episodes(
    self,
    episode_names: List = None,
    value_filter: str = "",
    display_parameters: List = None,
    print_results: bool = True,
) -> pd.DataFrame:
    """
    Get a filtered pandas dataframe with episode metadata

    Parameters
    ----------
    episode_names : list
        a list of strings of episode names
    value_filter : str
        a string of filters to apply; of this general structure:
        'group_name1/par_name1::(</>/<=/>=/=)value1,group_name2/par_name2::(</>/<=/>=/=)value2', e.g.
        'data/overlap::=50,results/recall::>0.5,data/feature_extraction::=kinematic,meta/training_time::>=00:00:10'
    display_parameters : list
        list of parameters to display (e.g. ['data/overlap', 'results/recall'])
    print_results : bool, default True
        if True, the result will be printed to standard output

    Returns
    -------
    pd.DataFrame
        the filtered dataframe
    """

    episodes = self._episodes().list_episodes(
        episode_names, value_filter, display_parameters
    )
    if print_results:
        print("TRAINING EPISODES")
        print(episodes)
        print("\n")
    return episodes
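For example, using the filter syntax above (the values are illustrative):

# episodes trained with 50% overlap that reached recall above 0.5
df = project.list_episodes(
    value_filter="data/overlap::=50,results/recall::>0.5",
    display_parameters=["data/overlap", "results/recall", "meta/training_time"],
)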
def list_predictions(
    self,
    episode_names: List = None,
    value_filter: str = "",
    display_parameters: List = None,
    print_results: bool = True,
) -> pd.DataFrame:
    """
    Get a filtered pandas dataframe with prediction metadata

    Parameters
    ----------
    episode_names : list
        a list of strings of episode names
    value_filter : str
        a string of filters to apply, with the same syntax as in `Project.list_episodes`, e.g.
        'data/overlap::=50,results/recall::>0.5,data/feature_extraction::=kinematic'
    display_parameters : list
        list of parameters to display (e.g. ['data/overlap', 'results/recall'])
    print_results : bool, default True
        if True, the result will be printed to standard output

    Returns
    -------
    pd.DataFrame
        the filtered dataframe
    """

    predictions = self._predictions().list_episodes(
        episode_names, value_filter, display_parameters
    )
    if print_results:
        print("PREDICTIONS")
        print(predictions)
        print("\n")
    return predictions
def list_searches(
    self,
    search_names: List = None,
    value_filter: str = "",
    display_parameters: List = None,
    print_results: bool = True,
) -> pd.DataFrame:
    """
    Get a filtered pandas dataframe with hyperparameter search metadata

    Parameters
    ----------
    search_names : list
        a list of strings of search names
    value_filter : str
        a string of filters to apply, with the same syntax as in `Project.list_episodes`, e.g.
        'data/overlap::=50,results/recall::>0.5,data/feature_extraction::=kinematic'
    display_parameters : list
        list of parameters to display (e.g. ['data/overlap', 'results/recall'])
    print_results : bool, default True
        if True, the result will be printed to standard output

    Returns
    -------
    pd.DataFrame
        the filtered dataframe
    """

    searches = self._searches().list_episodes(
        search_names, value_filter, display_parameters
    )
    if print_results:
        print("SEARCHES")
        print(searches)
        print("\n")
    return searches
def get_best_parameters(
    self,
    search_name: str,
    round_to_binary: List = None,
):
    """
    Get the best parameters found by a search as a nested dictionary, including the model name
    """

    params, model = self._searches().get_best_params(
        search_name, round_to_binary=round_to_binary
    )
    params = self._update(params, {"general": {"model_name": model}})
    return params
def list_best_parameters(
    self, search_name: str, print_results: bool = True
) -> Dict:
    """
    Get the raw dictionary of best parameters found by a search

    Parameters
    ----------
    search_name : str
        the name of the search
    print_results : bool, default True
        if True, the result will be printed to standard output

    Returns
    -------
    best_params : dict
        a dictionary of the best parameters where the keys are in '{group}/{name}' format
    """

    params = self._searches().get_best_params_raw(search_name)
    if print_results:
        print(f"SEARCH RESULTS {search_name}")
        for k, v in params.items():
            print(f"{k}: {v}")
        print("\n")
    return params
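The raw and nested forms can be combined with `update_parameters()` (the search name is hypothetical):

raw = project.list_best_parameters("tcn_search")    # '{group}/{name}' keys
params = project.get_best_parameters("tcn_search")  # nested dictionary
project.update_parameters(params)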
def plot_episodes(
    self,
    episode_names: List,
    metrics: List,
    modes: List = None,
    title: str = None,
    episode_labels: List = None,
    save_path: str = None,
    add_hlines: List = None,
    epoch_limits: List = None,
    colors: List = None,
    add_highpoint_hlines: bool = False,
) -> None:
    """
    Plot episode training curves

    Parameters
    ----------
    episode_names : list
        a list of episode names to plot; to plot several episodes as one continued experiment, combine them
        in a nested list (e.g. ['episode1', ['episode2', 'episode3']] to plot episode2 and episode3 as one
        experiment)
    metrics : list
        a list of metrics to plot
    modes : list, optional
        a list of modes to plot ('train' and/or 'val'; ['val'] by default)
    title : str, optional
        title for the plot
    episode_labels : list, optional
        a list of strings used to label the curves (has to be the same length as episode_names)
    save_path : str, optional
        the path to save the resulting plot
    add_hlines : list, optional
        a list of float values (or (value, label) tuples) to mark with horizontal lines
    epoch_limits : list, optional
        x-axis limits for the plot
    colors : list, optional
        a list of matplotlib colors
    add_highpoint_hlines : bool, default False
        if True, horizontal lines will be added at the highest value of each episode
    """

    if modes is None:
        modes = ["val"]
    if add_hlines is None:
        add_hlines = []
    logs = []
    epochs = []
    labels = []
    if episode_labels is not None:
        assert len(episode_labels) == len(episode_names)
    for name_i, name in enumerate(episode_names):
        log_params = product(metrics, modes)
        for metric, mode in log_params:
            if episode_labels is not None:
                label = episode_labels[name_i]
            else:
                label = deepcopy(name)
            if len(modes) != 1:
                label += f"_{mode}"
            if len(metrics) != 1:
                label += f"_{metric}"
            labels.append(label)
            if isinstance(name, Iterable) and not isinstance(name, str):
                epoch_list = defaultdict(lambda: [])
                multi_logs = defaultdict(lambda: [])
                for i, n in enumerate(name):
                    runs = self._episodes().get_runs(n)
                    if len(runs) > 1:
                        for run in runs:
                            index = run.split("::")[-1]
                            if multi_logs[index] == []:
                                if multi_logs["null"] is None:
                                    raise RuntimeError(
                                        "The run indices are not consistent across episodes!"
                                    )
                                else:
                                    multi_logs[index] += multi_logs["null"]
                            multi_logs[index] += list(
                                self._episode(run).get_metric_log(mode, metric)
                            )
                            start = (
                                0
                                if len(epoch_list[index]) == 0
                                else epoch_list[index][-1]
                            )
                            epoch_list[index] += [
                                x + start
                                for x in self._episode(run).get_epoch_list(mode)
                            ]
                        multi_logs["null"] = None
                    else:
                        if len(multi_logs.keys()) > 1:
                            raise RuntimeError(
                                "Cannot plot a single-run episode after a multi-run episode!"
                            )
                        multi_logs["null"] += list(
                            self._episode(n).get_metric_log(mode, metric)
                        )
                        start = (
                            0
                            if len(epoch_list["null"]) == 0
                            else epoch_list["null"][-1]
                        )
                        epoch_list["null"] += [
                            x + start for x in self._episode(n).get_epoch_list(mode)
                        ]
                if len(multi_logs.keys()) == 1:
                    log = multi_logs["null"]
                    epochs.append(epoch_list["null"])
                else:
                    log = tuple([v for k, v in multi_logs.items() if k != "null"])
                    epochs.append(
                        tuple([v for k, v in epoch_list.items() if k != "null"])
                    )
            else:
                runs = self._episodes().get_runs(name)
                if len(runs) > 1:
                    log = []
                    for run in runs:
                        tracked_metrics = self._episode(run).get_metrics()
                        if metric in tracked_metrics:
                            log.append(
                                list(
                                    self._episode(run).get_metric_log(mode, metric)
                                )
                            )
                        else:
                            relevant = []
                            for m in tracked_metrics:
                                m_split = m.split("_")
                                if (
                                    "_".join(m_split[:-1]) == metric
                                    and m_split[-1].isnumeric()
                                ):
                                    relevant.append(m)
                            if len(relevant) == 0:
                                raise ValueError(
                                    f"The {metric} metric was not tracked at {run}"
                                )
                            arr = 0
                            for m in relevant:
                                arr += self._episode(run).get_metric_log(mode, m)
                            arr /= len(relevant)
                            log.append(list(arr))
                    log = tuple(log)
                    epochs.append(
                        tuple(
                            [
                                self._episode(run).get_epoch_list(mode)
                                for run in runs
                            ]
                        )
                    )
                else:
                    tracked_metrics = self._episode(name).get_metrics()
                    if metric in tracked_metrics:
                        log = list(self._episode(name).get_metric_log(mode, metric))
                    else:
                        relevant = []
                        for m in tracked_metrics:
                            m_split = m.split("_")
                            if (
                                "_".join(m_split[:-1]) == metric
                                and m_split[-1].isnumeric()
                            ):
                                relevant.append(m)
                        if len(relevant) == 0:
                            raise ValueError(
                                f"The {metric} metric was not tracked at {name}"
                            )
                        arr = 0
                        for m in relevant:
                            arr += self._episode(name).get_metric_log(mode, m)
                        arr /= len(relevant)
                        log = list(arr)
                    epochs.append(self._episode(name).get_epoch_list(mode))
            logs.append(log)
    if colors is None:
        colors = cm.rainbow(np.linspace(0, 1, len(logs)))
    if len(colors) != len(logs):
        raise ValueError(
            "The length of colors has to be equal to the length of curves (metrics * modes * episode_names)!"
        )
    plt.figure()
    length = 0
    for log, label, color, epoch_list in zip(logs, labels, colors, epochs):
        if type(log) is list:
            if len(log) > length:
                length = len(log)
            plt.plot(
                epoch_list,
                log,
                label=label,
                color=color,
            )
            if add_highpoint_hlines:
                plt.axhline(np.max(log), linestyle="dashed", color=color)
        else:
            for l, xx in zip(log, epoch_list):
                if len(l) > length:
                    length = len(l)
                plt.plot(
                    xx,
                    l,
                    color=color,
                    alpha=0.2,
                )
            if not all([len(x) == len(log[0]) for x in log]):
                warnings.warn(
                    f"Got logs with unequal lengths in parallel runs for {label}"
                )
                log = list(log)
                epoch_list = list(epoch_list)
                for i, x in enumerate(epoch_list):
                    to_remove = []
                    for j, y in enumerate(x[1:]):
                        # compare with the previous epoch value (x[j] precedes y in x)
                        if y <= x[j]:
                            y_ind = x.index(y)
                            to_remove += list(range(y_ind, j))
                    epoch_list[i] = [
                        y for j, y in enumerate(x) if j not in to_remove
                    ]
                    log[i] = [y for j, y in enumerate(log[i]) if j not in to_remove]
                length = min([len(x) for x in log])
                for i in range(len(log)):
                    log[i] = log[i][:length]
                    epoch_list[i] = epoch_list[i][:length]
            if not all([x == epoch_list[0] for x in epoch_list]):
                raise RuntimeError(
                    f"Got different epoch indices in parallel runs for {label}"
                )
            mean = np.array(log).mean(0)
            plt.plot(
                epoch_list[0],
                mean,
                label=label,
                color=color,
            )
            if add_highpoint_hlines:
                plt.axhline(np.max(mean), linestyle="dashed", color=color)
    for x in add_hlines:
        label = None
        if isinstance(x, Iterable):
            x, label = x
        plt.axhline(x, label=label)
    plt.xlim((0, length))

    plt.legend()
    plt.xlabel("epochs")
    if len(metrics) == 1:
        plt.ylabel(metrics[0])
    else:
        plt.ylabel("value")
    if title is None:
        if len(episode_names) == 1:
            title = episode_names[0]
        elif len(metrics) == 1:
            title = metrics[0]
    if epoch_limits is not None:
        plt.xlim(epoch_limits)
    if title is not None:
        plt.title(title)
    # save before showing; plt.show() clears the current figure
    if save_path is not None:
        plt.savefig(save_path)
    plt.show()
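A usage sketch with hypothetical episode names, plotting one single episode and one two-part experiment:

project.plot_episodes(
    ["episode_1", ["episode_2", "episode_3"]],
    metrics=["f1"],
    modes=["val"],
    episode_labels=["baseline", "continued"],
    save_path="f1_curves.png",  # hypothetical output path
)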
def update_parameters(
    self,
    parameters_update: Dict = None,
    load_search: str = None,
    load_parameters: List = None,
    round_to_binary: List = None,
) -> None:
    """
    Update the parameters in the project config files

    Parameters
    ----------
    parameters_update : dict, optional
        a dictionary of parameter updates
    load_search : str, optional
        the name of hyperparameter search results to load to config
    load_parameters : list, optional
        a list of lists of string names of the parameters to load from the searches
    round_to_binary : list, optional
        a list of string names of the loaded parameters that should be rounded to the nearest power of two
    """

    keys = [
        "general",
        "losses",
        "metrics",
        "ssl",
        "training",
        "data",
    ]
    parameters = self._read_parameters(catch_blanks=False)
    if parameters_update is not None:
        if "model" in parameters_update:
            model_params = parameters_update.pop("model")
        else:
            model_params = None
        if "features" in parameters_update:
            feat_params = parameters_update.pop("features")
        else:
            feat_params = None
        if "augmentations" in parameters_update:
            aug_params = parameters_update.pop("augmentations")
        else:
            aug_params = None
        parameters = self._update(parameters, parameters_update)
        model_name = parameters["general"]["model_name"]
        parameters["model"] = self._open_yaml(
            os.path.join(self.project_path, "config", "model", f"{model_name}.yaml")
        )
        if model_params is not None:
            parameters["model"] = self._update(parameters["model"], model_params)
        feat_name = parameters["general"]["feature_extraction"]
        parameters["features"] = self._open_yaml(
            os.path.join(
                self.project_path, "config", "features", f"{feat_name}.yaml"
            )
        )
        if feat_params is not None:
            parameters["features"] = self._update(
                parameters["features"], feat_params
            )
        aug_name = options.extractor_to_transformer[
            parameters["general"]["feature_extraction"]
        ]
        parameters["augmentations"] = self._open_yaml(
            os.path.join(
                self.project_path, "config", "augmentations", f"{aug_name}.yaml"
            )
        )
        if aug_params is not None:
            parameters["augmentations"] = self._update(
                parameters["augmentations"], aug_params
            )
    if load_search is not None:
        parameters_update, model_name = self._searches().get_best_params(
            load_search, load_parameters, round_to_binary
        )
        parameters["general"]["model_name"] = model_name
        parameters["model"] = self._open_yaml(
            os.path.join(self.project_path, "config", "model", f"{model_name}.yaml")
        )
        parameters = self._update(parameters, parameters_update)
    for key in keys:
        with open(
            os.path.join(self.project_path, "config", f"{key}.yaml"),
            "w",
            encoding="utf-8",
        ) as f:
            YAML().dump(parameters[key], f)
    model_name = parameters["general"]["model_name"]
    model_path = os.path.join(
        self.project_path, "config", "model", f"{model_name}.yaml"
    )
    with open(model_path, "w", encoding="utf-8") as f:
        YAML().dump(parameters["model"], f)
    features_name = parameters["general"]["feature_extraction"]
    features_path = os.path.join(
        self.project_path, "config", "features", f"{features_name}.yaml"
    )
    with open(features_path, "w", encoding="utf-8") as f:
        YAML().dump(parameters["features"], f)
    aug_name = options.extractor_to_transformer[features_name]
    aug_path = os.path.join(
        self.project_path, "config", "augmentations", f"{aug_name}.yaml"
    )
    with open(aug_path, "w", encoding="utf-8") as f:
        YAML().dump(parameters["augmentations"], f)
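A usage sketch; the keys follow the config group structure and the values are illustrative:

project.update_parameters(
    {
        "data": {"overlap": 70},
        "training": {"num_epochs": 100, "lr": 1e-4},
        "general": {"metric_functions": {"f1", "recall"}},
    }
)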
def get_summary(
    self,
    episode_names: list,
    method: str = "last",
    average: int = 1,
    metrics: List = None,
) -> Dict:
    """
    Get a summary of episode statistics

    If the episode has multiple runs, the statistics will be aggregated over all of them.

    Parameters
    ----------
    episode_names : list
        a list of episode names
    method : {'best', 'last', 'epoch...'}
        the method for choosing the epochs ('epochN' takes the value at epoch N)
    average : int, default 1
        the number of epochs to average over (for each run)
    metrics : list, optional
        a list of metrics

    Returns
    -------
    statistics : dict
        a nested dictionary where first-level keys are metric names and second-level keys are 'mean' for the
        mean and 'std' for the standard deviation
    """

    runs = []
    for episode_name in episode_names:
        runs_ep = self._episodes().get_runs(episode_name)
        if len(runs_ep) == 0:
            raise RuntimeError(
                f"There is no {episode_name} episode in the project memory"
            )
        runs += runs_ep
    if metrics is None:
        metrics = self._episode(runs[0]).get_metrics()

    values = {m: [] for m in metrics}
    for run in runs:
        for m in metrics:
            log = self._episode(run).get_metric_log(mode="val", metric_name=m)
            if method == "best":
                log = sorted(log)
                values[m] += list(log[-average:])
            elif method == "last":
                if len(log) == 0:
                    # fall back to the meta records if the log is empty
                    episodes = self._episodes().data
                    if average == 1 and ("results", m) in episodes.columns:
                        values[m] += [episodes.loc[run, ("results", m)]]
                    else:
                        raise RuntimeError(f"Did not find {m} metric for {run} run")
                else:
                    values[m] += list(log[-average:])
            elif method.startswith("epoch"):
                epoch = int(method[5:]) - 1
                pars = self._episodes().load_parameters(run)
                step = int(pars["training"]["validation_interval"])
                values[m] += [log[epoch // step]]
            else:
                raise ValueError(
                    f"The {method} method is not recognized! Please choose from ['last', 'best', 'epoch...']"
                )
    statistics = defaultdict(lambda: {})
    for m, v in values.items():
        statistics[m]["mean"] = np.mean(v)
        statistics[m]["std"] = np.std(v)
    print(f"SUMMARY {episode_names}")
    for m, v in statistics.items():
        print(f'{m}: mean {v["mean"]:.3f}, std {v["std"]:.3f}')
    print("\n")
    return dict(statistics)
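A usage sketch, assuming a "test_episode" episode with several runs exists:

# mean and standard deviation of the three best validation f1 values per run
summary = project.get_summary(
    ["test_episode"],
    method="best",
    average=3,
    metrics=["f1"],
)
print(summary["f1"]["mean"], summary["f1"]["std"])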
@staticmethod
def remove_project(name: str, projects_path: str = None) -> None:
    """
    Remove all project files and experiment records and results
    """

    if projects_path is None:
        projects_path = os.path.join(str(Path.home()), "DLC2Action")
    project_path = os.path.join(projects_path, name)
    if os.path.exists(project_path):
        shutil.rmtree(project_path)
def remove_saved_features(
    self,
    dataset_names: List = None,
    exceptions: List = None,
    remove_active: bool = False,
) -> None:
    """
    Remove saved pre-computed dataset files

    By default, all pre-computed features will be deleted.
    No essential information is lost: storing the features only saves time. Be careful with deleting
    datasets while training or inference is happening though.

    Parameters
    ----------
    dataset_names : list, optional
        a list of dataset names to delete (by default all names are added)
    exceptions : list, optional
        a list of dataset names that will not be deleted
    remove_active : bool, default False
        if False, datasets used by unfinished episodes will not be deleted
    """

    print("Removing datasets...")
    if dataset_names is None:
        dataset_names = []
    if exceptions is None:
        exceptions = []
    if not remove_active:
        exceptions += self._episodes().get_active_datasets()
    dataset_path = os.path.join(self.project_path, "saved_datasets")
    if os.path.exists(dataset_path):
        if dataset_names == []:
            dataset_names = set([f.split(".")[0] for f in os.listdir(dataset_path)])

        to_remove = [
            x
            for x in dataset_names
            if os.path.exists(os.path.join(dataset_path, x)) and x not in exceptions
        ]
        if len(to_remove) > 2:
            to_remove = tqdm(to_remove)
        for dataset in to_remove:
            shutil.rmtree(os.path.join(dataset_path, dataset))
        to_remove = [
            f"{x}.pickle"
            for x in dataset_names
            if os.path.exists(os.path.join(dataset_path, f"{x}.pickle"))
            and x not in exceptions
        ]
        for dataset in to_remove:
            os.remove(os.path.join(dataset_path, dataset))
        names = self._saved_datasets().dataset_names()
        self._saved_datasets().remove(names)
    print("\n")
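For example (the dataset name is hypothetical):

# free disk space but keep the features of one dataset
project.remove_saved_features(exceptions=["my_active_dataset"])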
def remove_extra_checkpoints(
    self, episode_names: List = None, exceptions: List = None
) -> None:
    """
    Remove intermediate model checkpoint files (only keep the results of the last epoch)

    By default, all intermediate checkpoints will be deleted.
    Files in the model folder that are not associated with any record in the meta files are also deleted.

    Parameters
    ----------
    episode_names : list, optional
        a list of episode names to clean (by default all names are added)
    exceptions : list, optional
        a list of episode names to not clean
    """

    model_path = os.path.join(self.project_path, "results", "model")
    try:
        all_names = self._episodes().data.index
    except Exception:
        # fall back to the folder contents if the meta files cannot be read
        all_names = os.listdir(model_path)
    if episode_names is None:
        episode_names = all_names
    if exceptions is None:
        exceptions = []
    to_remove = [x for x in episode_names if x not in exceptions]
    folders = os.listdir(model_path)
    for folder in folders:
        if folder not in all_names:
            shutil.rmtree(os.path.join(model_path, folder))
        elif folder in to_remove:
            files = os.listdir(os.path.join(model_path, folder))
            for file in sorted(files)[:-1]:
                os.remove(os.path.join(model_path, folder, file))
def remove_search(self, search_name: str) -> None:
    """
    Remove a hyperparameter search record

    Parameters
    ----------
    search_name : str
        the name of the search to remove
    """

    self._searches().remove_episode(search_name)
    graph_path = os.path.join(self.project_path, "results", "searches", search_name)
    if os.path.exists(graph_path):
        shutil.rmtree(graph_path)
def remove_prediction(self, prediction_name: str) -> None:
    """
    Remove a prediction record

    Parameters
    ----------
    prediction_name : str
        the name of the prediction to remove
    """

    self._predictions().remove_episode(prediction_name)
    prediction_path = os.path.join(
        self.project_path, "results", "predictions", prediction_name
    )
    if os.path.exists(prediction_path):
        shutil.rmtree(prediction_path)
def remove_episode(self, episode_name: str) -> None:
    """
    Remove all model files, logs and meta records related to an episode

    Parameters
    ----------
    episode_name : str
        the name of the episode to remove
    """

    runs = self._episodes().get_runs(episode_name)
    runs.append(episode_name)
    for run in runs:
        self._episodes().remove_episode(run)
        model_path = os.path.join(self.project_path, "results", "model", run)
        if os.path.exists(model_path):
            shutil.rmtree(model_path)
        log_path = os.path.join(self.project_path, "results", "logs", f"{run}.txt")
        if os.path.exists(log_path):
            os.remove(log_path)
def prune_unfinished(self, exceptions: List = None) -> List:
    """
    Remove all interrupted episodes

    Remove all episodes that either don't have a log file, have fewer epochs in the log file than in
    the training parameters, or have a model folder but no meta record. Note that this can remove episodes
    that are currently running!

    Parameters
    ----------
    exceptions : list
        the episodes to keep even if they are interrupted

    Returns
    -------
    pruned : list
        a list of the episode names that were pruned
    """

    if exceptions is None:
        exceptions = []
    unfinished = self._episodes().unfinished_episodes()
    unfinished = [x for x in unfinished if x not in exceptions]
    model_folders = os.listdir(os.path.join(self.project_path, "results", "model"))
    unfinished += [
        x for x in model_folders if x not in self._episodes().list_episodes().index
    ]
    print(f"PRUNING {unfinished}")
    for episode_name in unfinished:
        self.remove_episode(episode_name)
    print("\n")
    return unfinished
def prediction_path(self, prediction_name: str) -> str:
    """
    Get the path where prediction files are saved

    Parameters
    ----------
    prediction_name : str
        name of the prediction

    Returns
    -------
    prediction_path : str
        the file path
    """

    return os.path.join(
        self.project_path, "results", "predictions", f"{prediction_name}"
    )
@staticmethod
def data_types() -> List:
    """
    Get available data types

    Returns
    -------
    list
        available data types
    """

    return options.input_stores
@staticmethod
def annotation_types() -> List:
    """
    Get available annotation types

    Returns
    -------
    list
        available annotation types
    """

    return options.annotation_stores
def set_main_parameters(self, model_name: str = None, metric_names: List = None):
    """
    Select the model and the metrics

    Parameters
    ----------
    model_name : str, optional
        model name; run `project.help("model")` to find out more
    metric_names : list, optional
        a list of metric function names; run `project.help("metrics")` to find out more
    """

    pars = {"general": {}}
    if model_name is not None:
        assert model_name in options.models
        pars["general"]["model_name"] = model_name
    if metric_names is not None:
        for metric in metric_names:
            assert metric in options.metrics
        pars["general"]["metric_functions"] = metric_names
    self.update_parameters(pars)
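A usage sketch; the model and metric names are illustrative (check project.help("model") and project.help("metrics") for the real options):

project.set_main_parameters(
    model_name="ms_tcn3",
    metric_names=["f1", "recall"],
)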
def help(self, keyword: str = None):
    """
    Get information on available options

    Parameters
    ----------
    keyword : str, optional
        the keyword for options (run without arguments to see which keywords are available)
    """

    if keyword is None:
        print("AVAILABLE HELP FUNCTIONS:")
        print("- Try running `project.help(keyword)` with the following keywords:")
        print("    - model: to get more information on available models,")
        print(
            "    - features: to get more information on available feature extraction modes,"
        )
        print(
            "    - partition_method: to get more information on available train/test/val partitioning methods,"
        )
        print("    - metrics: to see a list of available metric functions,")
        print("    - data: to see help for the expected data structure.")
        print(
            "- To start working with this project, first run `project.list_blanks()` to check which parameters need to be filled in."
        )
        print(
            "- After a model and metrics are set, run `project.list_basic_parameters()` to see a list of the most important parameters that you might want to modify."
        )
        print(
            "- If you want to dig deeper, get the full dictionary with project._read_parameters() (it is a `ruamel.yaml.comments.CommentedMap` instance)."
        )
    elif keyword == "model":
        print("MODELS:")
        for key, model in options.models.items():
            print(f"{key}:")
            print(model.__doc__)
    elif keyword == "features":
        print("FEATURE EXTRACTORS:")
        for key, extractor in options.feature_extractors.items():
            print(f"{key}:")
            print(extractor.__doc__)
    elif keyword == "partition_method":
        print("PARTITION METHODS:")
        print(
            BehaviorDataset.partition_train_test_val.__doc__.split(
                "The partitioning method:"
            )[1].split("val_frac :")[0]
        )
    elif keyword == "metrics":
        print("METRICS:")
        for key, metric in options.metrics.items():
            print(f"{key}:")
            print(metric.__doc__)
    elif keyword == "data":
        print("DATA:")
        print(f"Video data: {self.data_type}")
        print(options.input_stores[self.data_type].__doc__)
        print(f"Annotation data: {self.annotation_type}")
        print(options.annotation_stores[self.annotation_type].__doc__)
        print(
            "Annotation path and data path don't have to be separate, you can keep everything in one folder."
        )
    else:
        raise ValueError(f"The {keyword} keyword is not recognized")
    print("\n")
Get information on available options

Parameters
keyword : str, optional
    the keyword for options (run without arguments to see which keywords are available)
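A typical exploration session, assuming `project` is an initialized `Project` instance:

# print the list of available keywords
project.help()
# inspect the available models and their docstrings
project.help("model")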
def list_blanks(self, blanks=None):
    """
    List parameters that need to be filled in

    Parameters
    ----------
    blanks : list, optional
        a list of the parameters to list, if already known
    """

    if blanks is None:
        blanks = self._get_blanks()
    if len(blanks) > 0:
        to_update = defaultdict(lambda: [])
        for b, k, c in blanks:
            to_update[b].append((k, c))
        print("Before running experiments, please update all the blanks.")
        print("To do that, you can run this.")
        print("--------------------------------------------------------")
        print("project.update_parameters(")
        print("    {")
        for big_key, keys in to_update.items():
            print(f'        "{big_key}": {{')
            for key, comment in keys:
                print(f'            "{key}": ..., {comment}')
            print("        },")
        print("    }")
        print(")")
        print("--------------------------------------------------------")
        print("Replace ... with relevant values.")
    else:
        print("There are no blanks left!")
List parameters that need to be filled in

Parameters
blanks : list, optional
    a list of the parameters to list, if already known
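A minimal sketch of the resulting workflow (the "data" group and "data_suffix" key are hypothetical placeholders; the actual template is printed by `list_blanks`, and the `...` values are intentionally left for you to fill in):

# see which parameters still need values
project.list_blanks()
# then copy the printed template and fill it in
project.update_parameters(
    {
        "data": {
            "data_suffix": ...,  # replace ... with a relevant value
        },
    }
)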
def list_basic_parameters(self):
    """
    Get a list of the most relevant parameters and the code to modify them
    """

    parameters = self._read_parameters()
    print("BASIC PARAMETERS:")
    model_name = parameters["general"]["model_name"]
    metric_names = parameters["general"]["metric_functions"]
    loss_name = parameters["general"]["loss_function"]
    feature_extraction = parameters["general"]["feature_extraction"]
    print("Here is a list of current parameters.")
    print("You can copy this code, change the parameters you want to set and run it to update the project config.")
    print("--------------------------------------------------------")
    print("project.update_parameters(")
    print("    {")
    for group in ["general", "data", "training"]:
        print(f'        "{group}": {{')
        for key in options.basic_parameters[group]:
            if key in parameters[group]:
                print(f'            "{key}": {self._process_value(parameters[group][key])}, {self._get_comment(parameters[group].ca.items.get(key))}')
        print("        },")
    print('        "losses": {')
    print(f'            "{loss_name}": {{')
    for key in options.basic_parameters["losses"][loss_name]:
        if key in parameters["losses"][loss_name]:
            print(f'                "{key}": {self._process_value(parameters["losses"][loss_name][key])}, {self._get_comment(parameters["losses"][loss_name].ca.items.get(key))}')
    print("            },")
    print("        },")
    print('        "metrics": {')
    for metric in metric_names:
        print(f'            "{metric}": {{')
        for key in parameters["metrics"][metric]:
            print(f'                "{key}": {self._process_value(parameters["metrics"][metric][key])}, {self._get_comment(parameters["metrics"][metric].ca.items.get(key))}')
        print("            },")
    print("        },")
    print('        "model": {')
    for key in options.basic_parameters["model"][model_name]:
        if key in parameters["model"]:
            print(f'            "{key}": {self._process_value(parameters["model"][key])}, {self._get_comment(parameters["model"].ca.items.get(key))}')
    print("        },")
    print('        "features": {')
    for key in options.basic_parameters["features"][feature_extraction]:
        if key in parameters["features"]:
            print(f'            "{key}": {self._process_value(parameters["features"][key])}, {self._get_comment(parameters["features"].ca.items.get(key))}')
    print("        },")
    print('        "augmentations": {')
    for key in options.basic_parameters["augmentations"][feature_extraction]:
        if key in parameters["augmentations"]:
            print(f'            "{key}": {self._process_value(parameters["augmentations"][key])}, {self._get_comment(parameters["augmentations"].ca.items.get(key))}')
    print("        },")
    print("    },")
    print(")")
    print("--------------------------------------------------------")
    print("\n")
Get a list of the most relevant parameters and the code to modify them
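For instance (assuming `project` is an initialized `Project` instance with a model and metrics already set):

# print the current basic parameters as an editable update_parameters(...) snippet
project.list_basic_parameters()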
def count_classes(
    self,
    load_episode: str = None,
    parameters_update: Dict = None,
    remove_saved_features: bool = False,
    bouts: bool = True,
) -> Dict:
    """
    Get a dictionary of class counts in different modes

    Parameters
    ----------
    load_episode : str, optional
        the episode settings to load
    parameters_update : dict, optional
        a dictionary of parameter updates (only for "data" and "general" categories)
    remove_saved_features : bool, default False
        if `True`, the dataset that is used for computation is then deleted
    bouts : bool, default True
        if `True`, segment counts are returned instead of frame counts

    Returns
    -------
    class_counts : dict
        a dictionary where first-level keys are "train", "val" and "test", second-level keys are
        class names and values are class counts (in frames or bouts)
    """

    if load_episode is None:
        task, parameters = self._make_task_training(
            episode_name="_", parameters_update=parameters_update, throwaway=True
        )
    else:
        task, parameters, _ = self._make_task_prediction(
            "_",
            load_episode=load_episode,
            parameters_update=parameters_update,
        )
    class_counts = task.count_classes(bouts=bouts)
    behaviors = task.behaviors_dict()
    class_counts = {
        kk: {behaviors.get(k, "unknown"): v for k, v in vv.items()}
        for kk, vv in class_counts.items()
    }
    if remove_saved_features:
        self._remove_stores(parameters)
    return class_counts
Get a dictionary of class counts in different modes

Parameters
load_episode : str, optional
    the episode settings to load
parameters_update : dict, optional
    a dictionary of parameter updates (only for "data" and "general" categories)
remove_saved_features : bool, default False
    if `True`, the dataset that is used for computation is then deleted
bouts : bool, default True
    if `True`, segment counts are returned instead of frame counts

Returns
class_counts : dict
    a dictionary where first-level keys are "train", "val" and "test", second-level keys are class names and values are class counts (in frames or bouts)
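A minimal sketch, assuming `project` is an initialized `Project` instance and "first_experiment" is a hypothetical episode name:

# count bouts per class in each split, reusing the settings of an existing episode
counts = project.count_classes(load_episode="first_experiment", bouts=True)
print(counts["train"])  # e.g. {"running": 12, "grooming": 5} (illustrative output)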
def plot_class_distribution(
    self,
    parameters_update: Dict = None,
    frame_cutoff: int = 1,
    bout_cutoff: int = 1,
    print_full: bool = False,
    remove_saved_features: bool = False,
) -> None:
    """
    Make a class distribution plot

    You can either specify the parameters, choose an existing dataset or do neither (in that case a dataset
    is created or loaded for the computation with the default parameters).

    Parameters
    ----------
    parameters_update : dict, optional
        a dictionary of parameter updates (only for "data" and "general" categories)
    frame_cutoff : int, default 1
        classes with fewer frames than this are excluded from the frame plot
    bout_cutoff : int, default 1
        classes with fewer bouts than this are excluded from the bout plot
    print_full : bool, default False
        if `True`, the full (unfiltered) class counts are printed
    remove_saved_features : bool, default False
        if `True`, the dataset that is used for computation is then deleted
    """

    task, parameters = self._make_task_training(
        episode_name="_", parameters_update=parameters_update, throwaway=True
    )
    cutoff = {True: bout_cutoff, False: frame_cutoff}
    for bouts in [True, False]:
        class_counts = task.count_classes(bouts=bouts)
        if print_full:
            print("Bouts:" if bouts else "Frames:")
            for k, v in class_counts.items():
                if sum(v.values()) != 0:
                    print(f"  {k}:")
                    values, keys = zip(
                        *[
                            x
                            for x in sorted(zip(v.values(), v.keys()), reverse=True)
                            if x[-1] != -100
                        ]
                    )
                    for kk, vv in zip(keys, values):
                        print(f"    {task.behaviors_dict()[kk]}: {vv}")
        class_counts = {
            kk: {k: v for k, v in vv.items() if v >= cutoff[bouts]}
            for kk, vv in class_counts.items()
        }
        for key, d in class_counts.items():
            if sum(d.values()) != 0:
                values, keys = zip(
                    *[x for x in sorted(zip(d.values(), d.keys())) if x[-1] != -100]
                )
                keys = [task.behaviors_dict()[x] for x in keys]
                plt.bar(keys, values)
                plt.title(key)
                plt.xticks(rotation=45, ha="right")
                if bouts:
                    plt.ylabel("bouts")
                else:
                    plt.ylabel("frames")
                plt.tight_layout()
                plt.show()
    if remove_saved_features:
        self._remove_stores(parameters)
Make a class distribution plot

You can either specify the parameters, choose an existing dataset or do neither (in that case a dataset is created or loaded for the computation with the default parameters).

Parameters
parameters_update : dict, optional
    a dictionary of parameter updates (only for "data" and "general" categories)
frame_cutoff : int, default 1
    classes with fewer frames than this are excluded from the frame plot
bout_cutoff : int, default 1
    classes with fewer bouts than this are excluded from the bout plot
print_full : bool, default False
    if `True`, the full (unfiltered) class counts are printed
remove_saved_features : bool, default False
    if `True`, the dataset that is used for computation is then deleted
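For instance (assuming `project` is an initialized `Project` instance):

# plot frame and bout distributions, hiding classes with fewer than 5 bouts,
# and print the full counts as well
project.plot_class_distribution(bout_cutoff=5, print_full=True)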
def plot_confusion_matrix(
    self,
    episode_name: str,
    load_epoch: int = None,
    parameters_update: Dict = None,
    type: str = "recall",
    mode: str = "val",
    remove_saved_features: bool = False,
) -> Tuple[ndarray, Iterable]:
    """
    Make a confusion matrix plot and return the data

    If the annotation is non-exclusive, only false positive labels are considered.

    Parameters
    ----------
    episode_name : str
        the name of the episode to load
    load_epoch : int, optional
        the index of the epoch to load (by default the last one is loaded)
    parameters_update : dict, optional
        a dictionary of parameter updates (only for "data" and "general" categories)
    type : {"recall", "precision"}
        for datasets with non-exclusive annotation, if `type` is `"recall"`, only false positives are taken
        into account, and if `type` is `"precision"`, only false negatives
    mode : {'val', 'all', 'test', 'train'}
        the subset of the data to make the prediction for (forced to 'all' if data_path is not None)
    remove_saved_features : bool, default False
        if `True`, the dataset that is used for computation is then deleted

    Returns
    -------
    confusion_matrix : np.ndarray
        a confusion matrix of shape `(#classes, #classes)` where `A[i, j] = F_ij / N_i`, `F_ij` is the
        number of frames that have the i-th label in the ground truth and a false positive j-th label in
        the prediction, and `N_i` is the number of frames that have the i-th label in the ground truth
    classes : list
        a list of labels
    """

    task, parameters, mode = self._make_task_prediction(
        "_",
        load_episode=episode_name,
        load_epoch=load_epoch,
        parameters_update=parameters_update,
        mode=mode,
    )
    dataset = task.dataset(mode)
    prediction = task.predict(dataset, raw_output=True)
    confusion_matrix, classes, type = dataset.get_confusion_matrix(prediction, type)
    if remove_saved_features:
        self._remove_stores(parameters)
    fig, ax = plt.subplots(figsize=(len(classes), len(classes)))
    ax.imshow(confusion_matrix)
    # Show all ticks and label them with the respective list entries
    ax.set_xticks(np.arange(len(classes)))
    ax.set_xticklabels(classes)
    ax.set_yticks(np.arange(len(classes)))
    ax.set_yticklabels(classes)
    # Rotate the tick labels and set their alignment
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
    # Loop over data dimensions and create text annotations
    for i in range(len(classes)):
        for j in range(len(classes)):
            ax.text(
                j,
                i,
                np.round(confusion_matrix[i, j], 2),
                ha="center",
                va="center",
                color="w",
            )
    if type is not None:
        ax.set_title(f"{type} {episode_name}")
    else:
        ax.set_title(episode_name)
    fig.tight_layout()
    plt.show()
    return confusion_matrix, classes
Make a confusion matrix plot and return the data

If the annotation is non-exclusive, only false positive labels are considered.

Parameters
episode_name : str
    the name of the episode to load
load_epoch : int, optional
    the index of the epoch to load (by default the last one is loaded)
parameters_update : dict, optional
    a dictionary of parameter updates (only for "data" and "general" categories)
type : {"recall", "precision"}
    for datasets with non-exclusive annotation, if `type` is `"recall"`, only false positives are taken into account, and if `type` is `"precision"`, only false negatives
mode : {'val', 'all', 'test', 'train'}
    the subset of the data to make the prediction for (forced to 'all' if data_path is not None)
remove_saved_features : bool, default False
    if `True`, the dataset that is used for computation is then deleted

Returns
confusion_matrix : np.ndarray
    a confusion matrix of shape `(#classes, #classes)` where `A[i, j] = F_ij / N_i`, `F_ij` is the number of frames that have the i-th label in the ground truth and a false positive j-th label in the prediction, and `N_i` is the number of frames that have the i-th label in the ground truth
classes : list
    a list of labels
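A minimal sketch, assuming `project` is an initialized `Project` instance and "first_experiment" is a hypothetical episode name:

# plot a recall confusion matrix on the validation subset and keep the raw data
cm, classes = project.plot_confusion_matrix(
    episode_name="first_experiment",
    type="recall",
    mode="val",
)
print(classes, cm.shape)  # the matrix is square with one row per class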
def plot_predictions(
    self,
    episode_name: str,
    load_epoch: int = None,
    parameters_update: Dict = None,
    add_legend: bool = True,
    ground_truth: bool = True,
    colormap: str = "viridis",
    hide_axes: bool = False,
    min_classes: int = 1,
    width: float = 10,
    whole_video: bool = False,
    transparent: bool = False,
    drop_classes: Set = None,
    search_classes: Set = None,
    num_plots: int = 1,
    remove_saved_features: bool = False,
    smooth_interval_prediction: int = 0,
    data_path: str = None,
    file_paths: Set = None,
    mode: str = "val",
    behavior_name: str = None,
) -> None:
    """
    Visualize random predictions

    Parameters
    ----------
    episode_name : str
        the name of the episode to load
    load_epoch : int, optional
        the epoch to load (by default the last one)
    parameters_update : dict, optional
        a parameter update dictionary
    add_legend : bool, default True
        if `True`, a legend will be added to the plot
    ground_truth : bool, default True
        if `True`, ground truth will be added to the plot
    colormap : str, default 'viridis'
        the `matplotlib` colormap to use
    hide_axes : bool, default False
        if `True`, the axes will be hidden on the plot
    min_classes : int, default 1
        the minimum number of classes in a displayed interval
    width : float, default 10
        the width of the plot
    whole_video : bool, default False
        if `True`, whole videos are plotted instead of segments
    transparent : bool, default False
        if `True`, the background on the plot is transparent
    drop_classes : set, optional
        a set of class names not to be displayed
    search_classes : set, optional
        if given, only intervals where at least one of the classes is in the ground truth will be shown
    num_plots : int, default 1
        the number of plots to make
    remove_saved_features : bool, default False
        if `True`, the dataset will be deleted after computation
    smooth_interval_prediction : int, default 0
        if > 0, predictions shorter than this number of frames are removed (filled with the prediction
        for the previous frame)
    data_path : str, optional
        the data path to run the prediction for
    file_paths : set, optional
        a set of string file paths (data with all prefixes + feature files, in any order) to run the
        prediction for
    mode : {'all', 'test', 'val', 'train'}
        the subset of the data to make the prediction for (forced to 'all' if data_path is not None)
    behavior_name : str, optional
        for non-exclusive classification datasets, the behavior to visualize (by default the first in
        the list)
    """

    other_path = os.path.join(self.project_path, "results", "other")
    task, parameters, mode = self._make_task_prediction(
        "_",
        load_episode=episode_name,
        parameters_update=parameters_update,
        load_epoch=load_epoch,
        data_path=data_path,
        file_paths=file_paths,
        mode=mode,
    )
    if not os.path.exists(other_path):
        os.mkdir(other_path)
    for i in range(num_plots):
        task.visualize_results(
            save_path=os.path.join(other_path, f"{episode_name}_prediction_{i}.jpg"),
            add_legend=add_legend,
            ground_truth=ground_truth,
            colormap=colormap,
            hide_axes=hide_axes,
            min_classes=min_classes,
            whole_video=whole_video,
            transparent=transparent,
            dataset=mode,
            drop_classes=drop_classes,
            search_classes=search_classes,
            width=width,
            smooth_interval_prediction=smooth_interval_prediction,
            behavior_name=behavior_name,
        )
    if remove_saved_features:
        self._remove_stores(parameters)
Visualize random predictions

Parameters
episode_name : str
    the name of the episode to load
load_epoch : int, optional
    the epoch to load (by default the last one)
parameters_update : dict, optional
    a parameter update dictionary
add_legend : bool, default True
    if `True`, a legend will be added to the plot
ground_truth : bool, default True
    if `True`, ground truth will be added to the plot
colormap : str, default 'viridis'
    the `matplotlib` colormap to use
hide_axes : bool, default False
    if `True`, the axes will be hidden on the plot
min_classes : int, default 1
    the minimum number of classes in a displayed interval
width : float, default 10
    the width of the plot
whole_video : bool, default False
    if `True`, whole videos are plotted instead of segments
transparent : bool, default False
    if `True`, the background on the plot is transparent
drop_classes : set, optional
    a set of class names not to be displayed
search_classes : set, optional
    if given, only intervals where at least one of the classes is in the ground truth will be shown
num_plots : int, default 1
    the number of plots to make
remove_saved_features : bool, default False
    if `True`, the dataset will be deleted after computation
smooth_interval_prediction : int, default 0
    if > 0, predictions shorter than this number of frames are removed (filled with the prediction for the previous frame)
data_path : str, optional
    the data path to run the prediction for
mode : {'all', 'test', 'val', 'train'}
    the subset of the data to make the prediction for (forced to 'all' if data_path is not None)
file_paths : set, optional
    a set of string file paths (data with all prefixes + feature files, in any order) to run the prediction for
behavior_name : str, optional
    for non-exclusive classification datasets, the behavior to visualize (by default the first in the list)
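For instance (assuming `project` is an initialized `Project` instance; the episode and class names are hypothetical), this saves the plots under results/other in the project folder:

project.plot_predictions(
    episode_name="first_experiment",  # hypothetical episode name
    num_plots=3,
    search_classes={"grooming"},  # hypothetical class name; only intervals containing it are shown
    mode="val",
)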
def create_metadata_backup(self) -> None:
    """
    Create a copy of the meta files
    """

    meta_copy_path = os.path.join(self.project_path, "meta", "backup")
    meta_path = os.path.join(self.project_path, "meta")
    if os.path.exists(meta_copy_path):
        shutil.rmtree(meta_copy_path)
    os.mkdir(meta_copy_path)
    for file in os.listdir(meta_path):
        if file == "backup":
            continue
        shutil.copy(
            os.path.join(meta_path, file), os.path.join(meta_copy_path, file)
        )
Create a copy of the meta files
def load_metadata_backup(self) -> None:
    """
    Load from a previously created metadata backup (in case of corruption)
    """

    meta_copy_path = os.path.join(self.project_path, "meta", "backup")
    meta_path = os.path.join(self.project_path, "meta")
    for file in os.listdir(meta_copy_path):
        shutil.copy(
            os.path.join(meta_copy_path, file), os.path.join(meta_path, file)
        )
Load from a previously created metadata backup (in case of corruption)
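For instance (assuming `project` is an initialized `Project` instance):

# snapshot the meta files before a risky operation...
project.create_metadata_backup()
# ...and restore them if something goes wrong
project.load_metadata_backup()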
def get_behavior_dictionary(self, episode_name: str) -> Dict:
    """
    Get the behavior dictionary for an episode

    Parameters
    ----------
    episode_name : str
        the name of the episode

    Returns
    -------
    behaviors_dictionary : dict
        a dictionary where keys are label indices and values are label names
    """

    run = self._episodes().get_runs(episode_name)[0]
    return self._episode(run).get_behaviors_dict()
Get the behavior dictionary for an episode

Parameters
episode_name : str
    the name of the episode

Returns
behaviors_dictionary : dict
    a dictionary where keys are label indices and values are label names
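A minimal sketch, assuming `project` is an initialized `Project` instance and "first_experiment" is a hypothetical episode name:

beh_dict = project.get_behavior_dictionary("first_experiment")
print(beh_dict)  # e.g. {0: "running", 1: "grooming"} (illustrative mapping)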
def import_episodes(
    self,
    episodes_directory: str,
    name_map: Dict = None,
    repeat_policy: str = "error",
) -> None:
    """
    Import episodes exported with `Project.export_episodes`

    Parameters
    ----------
    episodes_directory : str
        the path to the exported episodes directory
    name_map : dict, optional
        a name change dictionary for the episodes: keys are old names, values are new names
    repeat_policy : {'error', 'skip', 'force'}
        what to do when an episode name is already taken: raise an error, skip the episode or
        overwrite the existing one
    """

    if name_map is None:
        name_map = {}
    episodes = pd.read_pickle(os.path.join(episodes_directory, "episodes.pickle"))
    to_remove = []
    import_string = "Imported episodes: "
    for episode_name in episodes.index:
        if episode_name in name_map:
            import_string += f"{episode_name} "
            episode_name = name_map[episode_name]
            import_string += f"({episode_name}), "
        else:
            import_string += f"{episode_name}, "
        try:
            self._check_episode_validity(episode_name, allow_doublecolon=True)
        except ValueError as e:
            if str(e).endswith("is already taken!"):
                if repeat_policy == "skip":
                    to_remove.append(episode_name)
                elif repeat_policy == "force":
                    self.remove_episode(episode_name)
                elif repeat_policy == "error":
                    raise ValueError(
                        f"The {episode_name} episode name is already taken; please use the name_map parameter to rename it"
                    )
                else:
                    raise ValueError(
                        f"The {repeat_policy} repeat policy is not recognized; please choose from ['skip', 'force' and 'error']"
                    )
    episodes = episodes.drop(index=to_remove)
    self._episodes().update(
        episodes,
        name_map=name_map,
        force=(repeat_policy == "force"),
        data_path=self.data_path,
        annotation_path=self.annotation_path,
    )
    for episode_name in episodes.index:
        if episode_name in name_map:
            new_episode_name = name_map[episode_name]
        else:
            new_episode_name = episode_name
        model_dir = os.path.join(
            self.project_path, "results", "model", new_episode_name
        )
        old_model_dir = os.path.join(episodes_directory, "model", episode_name)
        if os.path.exists(model_dir):
            shutil.rmtree(model_dir)
        os.mkdir(model_dir)
        for file in os.listdir(old_model_dir):
            shutil.copyfile(
                os.path.join(old_model_dir, file), os.path.join(model_dir, file)
            )
        log_file = os.path.join(
            self.project_path, "results", "logs", f"{new_episode_name}.txt"
        )
        old_log_file = os.path.join(
            episodes_directory, "logs", f"{episode_name}.txt"
        )
        shutil.copyfile(old_log_file, log_file)
    print(import_string)
    print("\n")
Import episodes exported with `Project.export_episodes`

Parameters
episodes_directory : str
    the path to the exported episodes directory
name_map : dict, optional
    a name change dictionary for the episodes: keys are old names, values are new names
repeat_policy : {'error', 'skip', 'force'}
    what to do when an episode name is already taken: raise an error, skip the episode or overwrite the existing one
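For instance (assuming `project` is an initialized `Project` instance; the directory path and episode names are hypothetical):

project.import_episodes(
    "/path/to/exported_episodes",  # directory created by Project.export_episodes
    name_map={"first_experiment": "imported_experiment"},  # rename on import
    repeat_policy="skip",  # silently skip episodes whose names are already taken
)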
def export_episodes(
    self, episode_names: List, output_directory: str, name: str = None
) -> None:
    """
    Save selected episodes in a directory that can be imported into another project with
    `Project.import_episodes`

    Parameters
    ----------
    episode_names : list
        a list of string episode names
    output_directory : str
        the path to the directory where the episodes will be saved
    name : str, optional
        the name of the episodes directory (by default `exported_episodes`)
    """

    if name is None:
        name = "exported_episodes"
    if os.path.exists(
        os.path.join(output_directory, name + ".zip")
    ) or os.path.exists(os.path.join(output_directory, name)):
        i = 1
        while os.path.exists(
            os.path.join(output_directory, name + f"_{i}.zip")
        ) or os.path.exists(os.path.join(output_directory, name + f"_{i}")):
            i += 1
        name = name + f"_{i}"
    dest_dir = os.path.join(output_directory, name)
    os.mkdir(dest_dir)
    os.mkdir(os.path.join(dest_dir, "model"))
    os.mkdir(os.path.join(dest_dir, "logs"))
    runs = []
    for episode in episode_names:
        runs += self._episodes().get_runs(episode)
    for run in runs:
        shutil.copytree(
            os.path.join(self.project_path, "results", "model", run),
            os.path.join(dest_dir, "model", run),
        )
        shutil.copyfile(
            os.path.join(self.project_path, "results", "logs", f"{run}.txt"),
            os.path.join(dest_dir, "logs", f"{run}.txt"),
        )
    data = self._episodes().get_subset(runs)
    data.to_pickle(os.path.join(dest_dir, "episodes.pickle"))
Save selected episodes in a directory that can be imported into another project with `Project.import_episodes`

Parameters
episode_names : list
    a list of string episode names
output_directory : str
    the path to the directory where the episodes will be saved
name : str, optional
    the name of the episodes directory (by default `exported_episodes`)
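For instance (assuming `project` is an initialized `Project` instance; the episode names and output path are hypothetical):

project.export_episodes(
    ["first_experiment", "second_experiment"],  # hypothetical episode names
    output_directory="/path/to/exports",  # an "exported_episodes" directory is created here
)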
def get_results_table(
    self,
    episode_names: List,
    metrics: List = None,
    include_std: bool = False,
    classes: List = None,
):
    """
    Generate a `pandas` dataframe with a summary of episode results

    Parameters
    ----------
    episode_names : list
        a list of names of episodes to include
    metrics : list, optional
        a list of metric names to include
    include_std : bool, default False
        if `True`, for episodes with multiple runs both the mean and the standard deviation are
        displayed; otherwise only the mean
    classes : list, optional
        a list of names of classes to include (by default all are included)

    Returns
    -------
    results : pd.DataFrame
        a table with the results
    """

    run_names = []
    for episode in episode_names:
        run_names += self._episodes().get_runs(episode)
    episodes = self.list_episodes(run_names, print_results=False)
    metric_columns = [x for x in episodes.columns if x[0] == "results"]
    results_df = pd.DataFrame()
    if metrics is not None:
        metric_columns = [
            x for x in metric_columns if x[1].split("_")[0] in metrics
        ]
    for episode in episode_names:
        results = []
        metric_set = set()
        for run in self._episodes().get_runs(episode):
            beh_dict = self.get_behavior_dictionary(run)
            res_dict = defaultdict(lambda: {})
            for column in metric_columns:
                if np.isnan(episodes.loc[run, column]):
                    continue
                split = column[1].split("_")
                if split[-1].isnumeric():
                    beh_ind = int(split[-1])
                    metric_name = "_".join(split[:-1])
                    beh = beh_dict[beh_ind]
                else:
                    beh = "average"
                    metric_name = column[1]
                res_dict[beh][metric_name] = episodes.loc[run, column]
                metric_set.add(metric_name)
            if "average" not in res_dict:
                res_dict["average"] = {}
            for metric in metric_set:
                if metric not in res_dict["average"]:
                    arr = [
                        res_dict[beh][metric]
                        for beh in res_dict
                        if metric in res_dict[beh]
                    ]
                    res_dict["average"][metric] = np.mean(arr)
            results.append(res_dict)
        episode_results = {}
        for metric in metric_set:
            for beh in results[0].keys():
                if classes is not None and beh not in classes:
                    continue
                arr = []
                for res_dict in results:
                    if metric in res_dict[beh]:
                        arr.append(res_dict[beh][metric])
                if len(arr) > 0:
                    if include_std:
                        episode_results[(beh, f"{episode} {metric} mean")] = np.mean(arr)
                        episode_results[(beh, f"{episode} {metric} std")] = np.std(arr)
                    else:
                        episode_results[(beh, f"{episode} {metric}")] = np.mean(arr)
        for key, value in episode_results.items():
            results_df.loc[key[0], key[1]] = value
    print("RESULTS:")
    print(results_df)
    print("\n")
    return results_df
Generate a `pandas` dataframe with a summary of episode results

Parameters
episode_names : list
    a list of names of episodes to include
metrics : list, optional
    a list of metric names to include
include_std : bool, default False
    if `True`, for episodes with multiple runs both the mean and the standard deviation are displayed; otherwise only the mean
classes : list, optional
    a list of names of classes to include (by default all are included)

Returns
results : pd.DataFrame
    a table with the results
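A minimal sketch, assuming `project` is an initialized `Project` instance (the episode names and the "f1" metric are hypothetical; use whatever metrics your episodes were trained with):

results = project.get_results_table(
    ["first_experiment", "second_experiment"],  # hypothetical episode names
    metrics=["f1"],  # hypothetical metric name
    include_std=True,  # report mean and std over multiple runs per episode
)
results.to_csv("summary.csv")  # the returned dataframe can be saved like any other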
def episode_exists(self, episode_name: str) -> bool:
    """
    Check if an episode already exists

    Parameters
    ----------
    episode_name : str
        the episode name

    Returns
    -------
    exists : bool
        `True` if the episode exists
    """

    return self._episodes().check_name_validity(episode_name)
Check if an episode already exists

Parameters
episode_name : str
    the episode name

Returns
exists : bool
    `True` if the episode exists
def search_exists(self, search_name: str) -> bool:
    """
    Check if a search already exists

    Parameters
    ----------
    search_name : str
        the search name

    Returns
    -------
    exists : bool
        `True` if the search exists
    """

    return self._searches().check_name_validity(search_name)
Check if a search already exists

Parameters
search_name : str
    the search name

Returns
exists : bool
    `True` if the search exists
def prediction_exists(self, prediction_name: str) -> bool:
    """
    Check if a prediction already exists

    Parameters
    ----------
    prediction_name : str
        the prediction name

    Returns
    -------
    exists : bool
        `True` if the prediction exists
    """

    return self._predictions().check_name_validity(prediction_name)
Check if a prediction already exists

Parameters
prediction_name : str
    the prediction name

Returns
exists : bool
    `True` if the prediction exists
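For instance, a minimal sketch of guarding against name collisions before starting a new experiment (assuming `project` is an initialized `Project` instance; the name is hypothetical):

episode_name = "first_experiment"  # hypothetical name
if project.episode_exists(episode_name):
    print(f"{episode_name} is already taken, pick another name")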
def rename_episode(self, episode_name: str, new_episode_name: str):
    """
    Rename an episode

    Parameters
    ----------
    episode_name : str
        the current name of the episode
    new_episode_name : str
        the new name of the episode
    """

    # move the model folder
    shutil.move(
        os.path.join(self.project_path, "results", "model", episode_name),
        os.path.join(self.project_path, "results", "model", new_episode_name),
    )
    # move the log file
    shutil.move(
        os.path.join(self.project_path, "results", "logs", f"{episode_name}.txt"),
        os.path.join(
            self.project_path, "results", "logs", f"{new_episode_name}.txt"
        ),
    )
    # update the episode records
    self._episodes().rename_episode(episode_name, new_episode_name)
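Renaming moves the episode's model folder and log file and updates the episode records. For instance (assuming `project` is an initialized `Project` instance; both names are hypothetical):

project.rename_episode("first_experiment", "baseline_run")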