dlc2action.project.meta
Handling meta (history) files
#
# Copyright 2020-2022 by A. Mathis Group and contributors. All rights reserved.
#
# This project and all its files are licensed under GNU AGPLv3 or later version. A copy is included in dlc2action/LICENSE.AGPL.
#
"""
Handling meta (history) files
"""

import os
from typing import Dict, List, Tuple, Set, Union
import pandas as pd
from time import localtime, strftime
from copy import deepcopy
import numpy as np
from collections import defaultdict
import warnings
from dlc2action.utils import correct_path


class Run:
    """
    A class that manages operations with a single episode record
    """

    def __init__(
        self,
        episode_name: str,
        project_path: str,
        meta_path: str = None,
        params: Dict = None,
    ):
        """
        Parameters
        ----------
        episode_name : str
            the name of the episode
        project_path : str
            the path to the project folder
        meta_path : str, optional
            the path to the pickled SavedRuns dataframe
        params : dict, optional
            alternative to meta_path: a pre-loaded pandas Series of episode parameters
        """

        self.name = episode_name
        self.project_path = project_path
        if meta_path is not None:
            try:
                self.params = pd.read_pickle(meta_path).loc[episode_name]
            except Exception:
                raise ValueError(f"The {episode_name} episode does not exist!")
        elif params is not None:
            self.params = params
        else:
            raise ValueError("Either meta_path or params has to be not None")

    def training_time(self) -> int:
        """
        Get the training time in seconds

        Returns
        -------
        training_time : int
            the training time in seconds
        """

        time_str = self.params["meta"].get("training_time")
        try:
            if time_str is None or np.isnan(time_str):
                return np.nan
        except TypeError:
            pass
        h, m, s = time_str.split(":")
        seconds = int(h) * 3600 + int(m) * 60 + int(s)
        return seconds

    def model_file(self, load_epoch: int = None) -> str:
        """
        Get a checkpoint file path

        Parameters
        ----------
        load_epoch : int, optional
            the epoch to load (the closest checkpoint will be chosen; if not given, the last is used)

        Returns
        -------
        checkpoint_path : str
            the path to the checkpoint
        """

        model_path = correct_path(
            self.params["training"]["model_save_path"], self.project_path
        )
        if load_epoch is None:
            model_file = sorted(os.listdir(model_path))[-1]
        else:
            model_files = os.listdir(model_path)
            if len(model_files) == 0:
                model_file = None
            else:
                # checkpoint files are named like "epoch15.pt", so the epoch
                # index starts at character 5
                epochs = [int(file[5:].split(".")[0]) for file in model_files]
                diffs = [np.abs(epoch - load_epoch) for epoch in epochs]
                argmin = np.argmin(diffs)
                model_file = model_files[argmin]
        model_file = os.path.join(model_path, model_file)
        return model_file

    def dataset_name(self) -> str:
        """
        Get the dataset name

        Returns
        -------
        dataset_name : str
            the name of the dataset record
        """

        data_path = correct_path(
            self.params["data"]["feature_save_path"], self.project_path
        )
        dataset_name = os.path.basename(data_path)
        return dataset_name

    def split_file(self) -> str:
        """
        Get the split file

        Returns
        -------
        split_path : str
            the path to the split file
        """

        return correct_path(self.params["training"]["split_path"], self.project_path)

    def log_file(self) -> str:
        """
        Get the log file

        Returns
        -------
        log_path : str
            the path to the log file
        """

        return correct_path(self.params["training"]["log_file"], self.project_path)

    def split_info(self) -> Dict:
        """
        Get the train/test/val split information

        Returns
        -------
        split_info : dict
            a dictionary with [val_frac, test_frac, partition_method] keys and corresponding values
        """

        val_frac = self.params["training"]["val_frac"]
        test_frac = self.params["training"]["test_frac"]
        partition_method = self.params["training"]["partition_method"]
        return {
            "val_frac": val_frac,
            "test_frac": test_frac,
            "partition_method": partition_method,
        }

    def same_split_info(self, split_info: Dict) -> bool:
        """
        Check whether this episode has the same split information

        Parameters
        ----------
        split_info : dict
            a dictionary with [val_frac, test_frac, partition_method] keys and corresponding values from another episode

        Returns
        -------
        result : bool
            if True, this episode has the same split information
        """

        self_split_info = self.split_info()
        for k in ["val_frac", "test_frac", "partition_method"]:
            if self_split_info[k] != split_info[k]:
                return False
        return True

    def get_metrics(self) -> List:
        """
        Get a list of tracked metrics

        Returns
        -------
        metrics : list
            a list of tracked metric names
        """

        return self.params["general"]["metric_functions"]

    def get_metric_log(self, mode: str, metric_name: str) -> np.ndarray:
        """
        Get the metric log

        Parameters
        ----------
        mode : {'train', 'val'}
            the mode to get the log from
        metric_name : str
            the metric to get the log for (has to be one of the metrics computed for this episode during training)

        Returns
        -------
        log : np.ndarray
            the log of metric values (empty if the metric was not computed during training)
        """

        metric_array = []
        with open(self.log_file()) as f:
            for line in f.readlines():
                if mode == "train" and line.startswith("[epoch"):
                    line = line.split("]: ")[1]
                elif mode == "val" and line.startswith("validation"):
                    line = line.split("validation: ")[1]
                else:
                    continue
                metrics = line.split(", ")
                for metric in metrics:
                    name, value = metric.split()
                    if name == metric_name:
                        metric_array.append(float(value))
        return np.array(metric_array)

    def get_epoch_list(self, mode) -> List:
        """
        Get a list of epoch indices

        Parameters
        ----------
        mode : {'train', 'val'}
            the mode to get the epoch list for

        Returns
        -------
        epoch_list : list
            a list of int epoch indices
        """

        epoch_list = []
        with open(self.log_file()) as f:
            for line in f.readlines():
                if line.startswith("[epoch"):
                    epoch = int(line[7:].split("]:")[0])
                    if mode == "train":
                        epoch_list.append(epoch)
                    elif mode == "val":
                        epoch_list.append(epoch)
        return epoch_list

    # NOTE: this definition shadows the get_metrics defined above
    def get_metrics(self) -> List:
        """
        Get a list of metric names in the episode log

        Returns
        -------
        metrics : List
            a list of string metric names
        """

        metrics = []
        with open(self.log_file()) as f:
            for line in f.readlines():
                if line.startswith("[epoch"):
                    line = line.split("]: ")[1]
                elif line.startswith("validation"):
                    line = line.split("validation: ")[1]
                else:
                    continue
                metric_logs = line.split(", ")
                for metric in metric_logs:
                    name, _ = metric.split()
                    metrics.append(name)
                break
        return metrics

    def unfinished(self) -> bool:
        """
        Check whether this episode was interrupted

        Returns
        -------
        result : bool
            True if the number of epochs in the log file is smaller than in the parameters
        """

        num_epoch_theor = self.params["training"]["num_epochs"]
        log_file = self.log_file()
        if not isinstance(log_file, str):
            return False
        if not os.path.exists(log_file):
            return True
        with open(self.log_file()) as f:
            num_epoch = 0
            val = False
            for line in f.readlines():
                num_epoch += 1
                if num_epoch == 2 and line.startswith("validation"):
                    val = True
            if val:
                num_epoch //= 2
        return num_epoch < num_epoch_theor

    def get_class_ind(self, class_name: str) -> int:
        """
        Get the integer label from a class name

        Parameters
        ----------
        class_name : str
            the name of the class

        Returns
        -------
        class_ind : int
            the integer label
        """

        behaviors_dict = self.params["meta"]["behaviors_dict"]
        for k, v in behaviors_dict.items():
            if v == class_name:
                return k
        raise ValueError(
            f"The {class_name} class is not in classes predicted by {self.name} ({behaviors_dict})"
        )

    def get_behaviors_dict(self) -> Dict:
        """
        Get the behaviors dictionary of the episode

        Returns
        -------
        behaviors_dict : dict
            a dictionary with class indices as keys and labels as values
        """

        return self.params["meta"]["behaviors_dict"]

    def get_num_classes(self) -> int:
        """
        Get the number of classes in the episode

        Returns
        -------
        num_classes : int
            the number of classes
        """

        return len(self.params["meta"]["behaviors_dict"])


class DecisionThresholds:
    """
    A class that saves and looks up tuned decision thresholds
    """

    def __init__(self, path: str) -> None:
        """
        Parameters
        ----------
        path : str
            the path to the pickled thresholds dataframe
        """

        self.path = path
        self.data = pd.read_pickle(path)

    def save_thresholds(
        self,
        episode_names: List,
        epochs: List,
        metric_name: str,
        metric_parameters: Dict,
        thresholds: List,
    ) -> None:
        """
        Add a new record

        Parameters
        ----------
        episode_names : list
            the names of the episodes
        epochs : list
            the epoch indices
        metric_name : str
            the name of the metric the thresholds were tuned on
        metric_parameters : dict
            the metric parameter dictionary
        thresholds : list
            a list of float decision thresholds
        """

        episodes = set(zip(episode_names, epochs))
        for key in ["average", "threshold_value", "ignored_classes"]:
            if key in metric_parameters:
                metric_parameters.pop(key)
        parameters = {(metric_name, k): v for k, v in metric_parameters.items()}
        parameters["thresholds"] = thresholds
        parameters["episodes"] = episodes
        pars = {k: [v] for k, v in parameters.items()}
        self.data = pd.concat([self.data, pd.DataFrame.from_dict(pars)], axis=0)
        self._save()

    def find_thresholds(
        self,
        episode_names: List,
        epochs: List,
        metric_name: str,
        metric_parameters: Dict,
    ) -> Union[List, None]:
        """
        Find a record

        Parameters
        ----------
        episode_names : list
            the names of the episodes
        epochs : list
            the epoch indices
        metric_name : str
            the name of the metric the thresholds were tuned on
        metric_parameters : dict
            the metric parameter dictionary

        Returns
        -------
        thresholds : list
            a list of float decision thresholds (None if no matching record exists)
        """

        episodes = set(zip(episode_names, epochs))
        for key in ["average", "threshold_value", "ignored_classes"]:
            if key in metric_parameters:
                metric_parameters.pop(key)
        parameters = {(metric_name, k): v for k, v in metric_parameters.items()}
        parameters["episodes"] = episodes
        filter = deepcopy(parameters)
        for key, value in parameters.items():
            if value is None:
                filter.pop(key)
            elif key not in self.data.columns:
                return None
        data = self.data[(self.data[list(filter)] == pd.Series(filter)).all(axis=1)]
        if len(data) > 0:
            thresholds = data.iloc[0]["thresholds"]
            return thresholds
        else:
            return None

    def _save(self) -> None:
        """
        Save the records
        """

        self.data.copy().to_pickle(self.path)


class SavedRuns:
    """
    A class that manages operations with all episode (or prediction) records
    """

    def __init__(self, path: str, project_path: str) -> None:
        """
        Parameters
        ----------
        path : str
            the path to the pickled SavedRuns dataframe
        project_path : str
            the path to the project folder
        """

        self.path = path
        self.project_path = project_path
        self.data = pd.read_pickle(path)

    def update(
        self,
        data: pd.DataFrame,
        data_path: str,
        annotation_path: str,
        name_map: Dict = None,
        force: bool = False,
    ) -> None:
        """
        Update with new data

        Parameters
        ----------
        data : pd.DataFrame
            the new dataframe
        data_path : str
            the new data path
        annotation_path : str
            the new annotation path
        name_map : dict, optional
            the name change dictionary; keys are old episode names and values are new episode names
        force : bool, default False
            replace existing episodes if `True`
        """

        if name_map is None:
            name_map = {}
        data = data.rename(index=name_map)
        for episode in data.index:
            new_model = os.path.join(self.project_path, "results", "model", episode)
            data.loc[episode, ("training", "model_save_path")] = new_model
            new_log = os.path.join(
                self.project_path, "results", "logs", f"{episode}.txt"
            )
            data.loc[episode, ("training", "log_file")] = new_log
            old_split = data.loc[episode, ("training", "split_path")]
            if old_split is None:
                new_split = None
            else:
                new_split = os.path.join(
                    self.project_path, "results", "splits", os.path.basename(old_split)
                )
            data.loc[episode, ("training", "split_path")] = new_split
            data.loc[episode, ("data", "data_path")] = data_path
            data.loc[episode, ("data", "annotation_path")] = annotation_path
            if episode in self.data.index:
                if force:
                    self.data = self.data.drop(index=[episode])
                else:
                    raise RuntimeError(f"The {episode} episode name is already taken!")
        self.data = pd.concat([self.data, data])
        self._save()

    def get_subset(self, episode_names: List) -> pd.DataFrame:
        """
        Get a subset of the raw metadata

        Parameters
        ----------
        episode_names : list
            a list of the episodes to include
        """

        for episode in episode_names:
            if episode not in self.data.index:
                raise ValueError(
                    f"The {episode} episode is not in the records; please run `Project.list_episodes()` to explore the records"
                )
        return self.data.loc[episode_names]

    def get_saved_data_path(self, episode_name: str) -> str:
        """
        Get the `saved_data_path` parameter for the episode

        Parameters
        ----------
        episode_name : str
            the name of the episode

        Returns
        -------
        saved_data_path : str
            the saved data path
        """

        return self.data.loc[episode_name]["data"]["saved_data_path"]

    def check_name_validity(self, episode_name: str) -> bool:
        """
        Check if an episode name already exists

        Parameters
        ----------
        episode_name : str
            the name to check

        Returns
        -------
        result : bool
            True if the name can be used
        """

        if episode_name in self.data.index:
            return False
        else:
            return True

    def update_episode_metrics(self, episode_name: str, metrics: Dict) -> None:
        """
        Update meta data with evaluation results

        Parameters
        ----------
        episode_name : str
            the name of the episode to update
        metrics : dict
            a dictionary of the metrics
        """

        for key, value in metrics.items():
            self.data.loc[episode_name, ("results", key)] = value
        self._save()

    def save_episode(
        self,
        episode_name: str,
        parameters: Dict,
        behaviors_dict: Dict,
        suppress_validation: bool = False,
        training_time: str = None,
    ) -> None:
        """
        Save a new run record

        Parameters
        ----------
        episode_name : str
            the name of the episode
        parameters : dict
            the parameters to save
        behaviors_dict : dict
            the dictionary of behaviors (keys are indices, values are names)
        suppress_validation : bool, default False
            if True, an existing episode with the same name will be overwritten
        training_time : str, optional
            the training time in '%H:%M:%S' format
        """

        if not suppress_validation and episode_name in self.data.index:
            raise ValueError(f"Episode {episode_name} already exists!")
        pars = deepcopy(parameters)
        if "meta" not in pars:
            pars["meta"] = {
                "time": strftime("%Y-%m-%d %H:%M:%S", localtime()),
                "behaviors_dict": behaviors_dict,
            }
        else:
            pars["meta"]["time"] = strftime("%Y-%m-%d %H:%M:%S", localtime())
            pars["meta"]["behaviors_dict"] = behaviors_dict
        if training_time is not None:
            pars["meta"]["training_time"] = training_time
        if len(parameters.keys()) > 1:
            pars["losses"] = pars["losses"].get(pars["general"]["loss_function"], {})
            for metric_name in pars["general"]["metric_functions"]:
                pars[metric_name] = pars["metrics"].get(metric_name, {})
            if pars["general"].get("ssl", None) is not None:
                for ssl_name in pars["general"]["ssl"]:
                    pars[ssl_name] = pars["ssl"].get(ssl_name, {})
            for group_name in ["metrics", "ssl"]:
                if group_name in pars:
                    pars.pop(group_name)
        data = {
            (big_key, small_key): value
            for big_key, big_value in pars.items()
            for small_key, value in big_value.items()
        }
        list_keys = []
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", message="DataFrame is highly fragmented")
            for k, v in data.items():
                if k not in self.data.columns:
                    self.data[k] = np.nan
                if isinstance(v, list) and not isinstance(v, str):
                    list_keys.append(k)
            for k in list_keys:
                self.data[k] = self.data[k].astype(object)
            self.data.loc[episode_name] = data
        self._save()

    def load_parameters(self, episode_name: str) -> Dict:
        """
        Load the task parameters from a record

        Parameters
        ----------
        episode_name : str
            the name of the episode to load

        Returns
        -------
        parameters : dict
            the loaded task parameters
        """

        parameters = defaultdict(lambda: defaultdict(lambda: {}))
        episode = self.data.loc[episode_name].dropna().to_dict()
        keys = ["data", "augmentations", "general", "training", "model", "features"]
        for key in episode:
            big_key, small_key = key
            if big_key in keys:
                parameters[big_key][small_key] = episode[key]
        ssl_keys = parameters["general"].get("ssl", None)
        metric_keys = parameters["general"].get("metric_functions", None)
        loss_key = parameters["general"]["loss_function"]
        if ssl_keys is None:
            ssl_keys = []
        if metric_keys is None:
            metric_keys = []
        for key in episode:
            big_key, small_key = key
            if big_key in ssl_keys:
                parameters["ssl"][big_key][small_key] = episode[key]
            elif big_key in metric_keys:
                parameters["metrics"][big_key][small_key] = episode[key]
            elif big_key == "losses":
                parameters["losses"][loss_key][small_key] = episode[key]
        parameters = {k: dict(v) for k, v in parameters.items()}
        parameters["general"]["num_classes"] = Run(
            episode_name, self.project_path, params=self.data.loc[episode_name]
        ).get_num_classes()
        return parameters

    def get_active_datasets(self) -> List:
        """
        Get a list of names of datasets that are used by unfinished episodes

        Returns
        -------
        active_datasets : list
            a list of dataset names used by unfinished episodes
        """

        active_datasets = []
        for episode_name in self.unfinished_episodes():
            run = Run(
                episode_name, self.project_path, params=self.data.loc[episode_name]
            )
            active_datasets.append(run.dataset_name())
        return active_datasets

    def list_episodes(
        self,
        episode_names: List = None,
        value_filter: str = "",
        display_parameters: List = None,
    ) -> pd.DataFrame:
        """
        Get a filtered pandas dataframe with episode metadata

        Parameters
        ----------
        episode_names : list
            a list of strings of episode names
        value_filter : str
            a string of filters to apply of this general structure:
            'group_name1/par_name1::(<>=)value1,group_name2/par_name2::(<>=)value2', e.g.
            'data/overlap::=50,results/recall::>0.5,data/feature_extraction::=kinematic'
        display_parameters : list
            list of parameters to display (e.g. ['data/overlap', 'results/recall'])

        Returns
        -------
        pandas.DataFrame
            the filtered dataframe
        """

        if episode_names is not None:
            data = deepcopy(self.data.loc[episode_names])
        else:
            data = deepcopy(self.data)
        if len(data) == 0:
            return pd.DataFrame()
        try:
            filters = value_filter.split(",")
            if filters == [""]:
                filters = []
            for f in filters:
                par_name, condition = f.split("::")
                group_name, par_name = par_name.split("/")
                sign, value = condition[0], condition[1:]
                if value[0] == "=":
                    sign += "="
                    value = value[1:]
                try:
                    value = float(value)
                except ValueError:
                    if value == "True":
                        value = True
                    elif value == "False":
                        value = False
                    elif value == "None":
                        value = None
                if value is None:
                    if sign == "=":
                        data = data[data[group_name][par_name].isna()]
                    elif sign == "!=":
                        data = data[~data[group_name][par_name].isna()]
                elif sign == ">":
                    data = data[data[group_name][par_name] > value]
                elif sign == ">=":
                    data = data[data[group_name][par_name] >= value]
                elif sign == "<":
                    data = data[data[group_name][par_name] < value]
                elif sign == "<=":
                    data = data[data[group_name][par_name] <= value]
                elif sign == "=":
                    data = data[data[group_name][par_name] == value]
                elif sign == "!=":
                    data = data[data[group_name][par_name] != value]
                else:
                    raise ValueError(
                        "Please use one of the signs: [>, <, >=, <=, =, !=]"
                    )
        except ValueError:
            raise ValueError(
                f"The {value_filter} filter is not valid, please use the following format:"
                f" 'group1/parameter1::[sign][value],group2/parameter2::[sign][value]', "
                f"e.g. 'training/num_epochs::>=200,model/num_f_maps::=128,meta/time::>2022-06-01'"
            )
        if display_parameters is not None:
            if type(display_parameters[0]) is str:
                display_parameters = [
                    (x.split("/")[0], x.split("/")[1]) for x in display_parameters
                ]
            display_parameters = [x for x in display_parameters if x in data.columns]
            data = data[display_parameters]
        return data

    def rename_episode(self, episode_name, new_episode_name):
        if episode_name in self.data.index and new_episode_name not in self.data.index:
            self.data.loc[new_episode_name] = self.data.loc[episode_name]
            model_path = self.data.loc[new_episode_name, ("training", "model_save_path")]
            self.data.loc[new_episode_name, ("training", "model_save_path")] = os.path.join(
                os.path.dirname(model_path), new_episode_name
            )
            log_path = self.data.loc[new_episode_name, ("training", "log_file")]
            self.data.loc[new_episode_name, ("training", "log_file")] = os.path.join(
                os.path.dirname(log_path), f"{new_episode_name}.txt"
            )
            self.data = self.data.drop(index=episode_name)
            self._save()
        else:
            raise ValueError(
                f"Cannot rename {episode_name} to {new_episode_name}: the old name has to exist and the new name has to be free"
            )

    def remove_episode(self, episode_name: str) -> None:
        """
        Remove all model, logs and metafile records related to an episode

        Parameters
        ----------
        episode_name : str
            the name of the episode to remove
        """

        if episode_name in self.data.index:
            self.data = self.data.drop(index=episode_name)
            self._save()

    def unfinished_episodes(self) -> List:
        """
        Get a list of unfinished episodes (currently running or interrupted)

        Returns
        -------
        interrupted_episodes : list
            a list of string names of unfinished episodes in the records
        """

        unfinished = []
        for name, params in self.data.iterrows():
            if Run(name, project_path=self.project_path, params=params).unfinished():
                unfinished.append(name)
        return unfinished

    def update_episode_results(
        self,
        episode_name: str,
        logs: Tuple,
        training_time: str = None,
    ) -> None:
        """
        Add results to an episode record

        Parameters
        ----------
        episode_name : str
            the name of the episode to update
        logs : tuple
            a log tuple from task.train(); the second element is the metrics dictionary
        training_time : str
            the training time
        """

        metrics_log = logs[1]
        results = {}
        for key, value in metrics_log["val"].items():
            results[("results", key)] = value[-1]
        if training_time is not None:
            results[("meta", "training_time")] = training_time
        for k, v in results.items():
            self.data.loc[episode_name, k] = v
        self._save()

    def get_runs(self, episode_name: str) -> List:
        """
        Get a list of runs with this episode name (episodes like `episode_name::0`)

        Parameters
        ----------
        episode_name : str
            the name of the episode

        Returns
        -------
        runs_list : list
            a list of string run names
        """

        if episode_name is None:
            return []
        index = self.data.index
        runs_list = []
        for name in index:
            if name.startswith(episode_name):
                split = name.split("::")
                if split[0] == episode_name:
                    if (len(split) > 1 and split[-1].isnumeric()) or len(split) == 1:
                        runs_list.append(name)
                elif name == episode_name:
                    runs_list.append(name)
        return runs_list

    def _save(self):
        """
        Save the dataframe
        """

        self.data.copy().to_pickle(self.path)


class Searches(SavedRuns):
    """
    A class that manages operations with search records
    """

    def save_search(
        self,
        search_name: str,
        parameters: Dict,
        n_trials: int,
        best_params: Dict,
        best_value: float,
        metric: str,
        search_space: Dict,
    ) -> None:
        """
        Save a new search record

        Parameters
        ----------
        search_name : str
            the name of the search to save
        parameters : dict
            the task parameters to save
        n_trials : int
            the number of trials in the search
        best_params : dict
            the best parameters dictionary
        best_value : float
            the best value
        metric : str
            the name of the objective metric
        search_space : dict
            a dictionary representing the search space; of this general structure:
            {'group/param_name': ('float/int/float_log/int_log', start, end),
            'group/param_name': ('categorical', [choices])}, e.g.
            {'data/overlap': ('int', 5, 100), 'training/lr': ('float_log', 1e-4, 1e-2),
            'data/feature_extraction': ('categorical', ['kinematic', 'bones'])}
        """

        pars = deepcopy(parameters)
        pars["results"] = {"best_value": best_value, "best_params": best_params}
        pars["meta"] = {
            "objective": metric,
            "n_trials": n_trials,
            "search_space": search_space,
        }
        self.save_episode(search_name, pars, {})

    def get_best_params_raw(self, search_name: str) -> Dict:
        """
        Get the raw dictionary of best parameters found by a search

        Parameters
        ----------
        search_name : str
            the name of the search

        Returns
        -------
        best_params : dict
            a dictionary of the best parameters where the keys are in '{group}/{name}' format
        """

        return self.data.loc[search_name]["results"]["best_params"]

    def get_best_params(
        self,
        search_name: str,
        load_parameters: List = None,
        round_to_binary: List = None,
    ) -> Tuple:
        """
        Get the best parameters from a search

        Parameters
        ----------
        search_name : str
            the name of the search
        load_parameters : list, optional
            a list of string names of the parameters to load (if not provided all parameters are loaded)
        round_to_binary : list, optional
            a list of string names of the loaded parameters that should be rounded to the nearest power of two

        Returns
        -------
        best_params : dict
            a dictionary of the best parameters
        model : str
            the name of the model used in the search
        """

        if round_to_binary is None:
            round_to_binary = []
        params = self.data.loc[search_name]["results"]["best_params"]
        if load_parameters is not None:
            params = {k: v for k, v in params.items() if k in load_parameters}
        for par_name in round_to_binary:
            if par_name not in params:
                continue
            if not isinstance(params[par_name], float) and not isinstance(
                params[par_name], int
            ):
                raise TypeError(
                    f"Cannot round {par_name} parameter of type {type(params[par_name])} to a power of two"
                )
            i = 1
            while 2**i < params[par_name]:
                i += 1
            if params[par_name] - (2 ** (i - 1)) < (2**i) - params[par_name]:
                params[par_name] = 2 ** (i - 1)
            else:
                params[par_name] = 2**i
        res = defaultdict(lambda: defaultdict(lambda: {}))
        for k, v in params.items():
            big_key, small_key = k.split("/")[0], "/".join(k.split("/")[1:])
            if len(small_key.split("/")) == 1:
                res[big_key][small_key] = v
            else:
                group, key = small_key.split("/")
                res[big_key][group][key] = v
        model = self.data.loc[search_name]["general"]["model_name"]
        return res, model


class Suggestions(SavedRuns):
    """
    A class that manages operations with suggestion records
    """

    def save_suggestion(self, episode_name: str, parameters: Dict, meta_parameters):
        pars = deepcopy(parameters)
        pars["meta"] = meta_parameters
        super().save_episode(episode_name, pars, behaviors_dict=None)


class SavedStores:
    """
    A class that manages operations with saved dataset records
    """

    def __init__(self, path):
        """
        Parameters
        ----------
        path : str
            the path to the pickled dataset records dataframe
        """

        self.path = path
        self.data = pd.read_pickle(path)
        self.skip_keys = [
            "feature_save_path",
            "saved_data_path",
            "real_lens",
            "recompute_annotation",
        ]

    def clear(self) -> None:
        """
        Remove all datasets
        """

        for dataset_name in self.data.index:
            self.remove_dataset(dataset_name)

    def dataset_names(self) -> List:
        """
        Get a list of dataset names

        Returns
        -------
        dataset_names : list
            a list of string dataset names
        """

        return list(self.data.index)

    def remove(self, names: List) -> None:
        """
        Remove some datasets

        Parameters
        ----------
        names : list
            a list of string names of the datasets to delete
        """

        for dataset_name in names:
            if dataset_name in self.data.index:
                self.remove_dataset(dataset_name)

    def remove_dataset(self, dataset_name: str) -> None:
        """
        Remove a dataset record

        Parameters
        ----------
        dataset_name : str
            the name of the dataset to remove
        """

        if dataset_name in self.data.index:
            self.data = self.data.drop(index=dataset_name)
            self._save()

    def find_name(self, parameters: Dict) -> str:
        """
        Find a record that satisfies the parameters (if it exists)

        Parameters
        ----------
        parameters : dict
            a dictionary of data parameters

        Returns
        -------
        name : str
            the name of a record that has the same parameters (None if it does not exist; the earliest if there are
            several)
        """

        filter = deepcopy(parameters)
        for key, value in parameters.items():
            if value is None or key in self.skip_keys:
                filter.pop(key)
            elif key not in self.data.columns:
                return None
        saved_annotation = self.data[
            (self.data[list(filter)] == pd.Series(filter)).all(axis=1)
        ]
        for i in range(len(saved_annotation)):
            ok = True
            for key in saved_annotation.columns:
                if key in self.skip_keys:
                    continue
                isnull = pd.isnull(saved_annotation.iloc[i][key])
                if not isinstance(isnull, bool):
                    isnull = False
                if key not in filter and not isnull:
                    ok = False
            if ok:
                name = saved_annotation.iloc[i].name
                return name
        return None

    def save_store(self, episode_name: str, parameters: Dict) -> None:
        """
        Save a new saved dataset record

        Parameters
        ----------
        episode_name : str
            the name of the dataset
        parameters : dict
            a dictionary of data parameters
        """

        pars = deepcopy(parameters)
        for k, v in parameters.items():
            if k not in self.data.columns:
                self.data[k] = np.nan
        if self.find_name(pars) is None:
            self.data.loc[episode_name] = pars
        self._save()

    def _save(self):
        """
        Save the dataframe
        """

        self.data.to_pickle(self.path)

    def check_name_validity(self, store_name: str) -> bool:
        """
        Check if a store name already exists

        Parameters
        ----------
        store_name : str
            the name to check

        Returns
        -------
        result : bool
            True if the name can be used
        """

        if store_name in self.data.index:
            return False
        else:
            return True
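All of the classes above wrap the same storage pattern: a pickled pandas DataFrame whose index holds record names and whose columns are two-level (group, parameter) pairs. A minimal sketch of that layout with fabricated values:

import pandas as pd

# records are rows; parameters are grouped two-level columns,
# e.g. ("training", "num_epochs"), as used throughout this module
data = pd.DataFrame(
    {
        ("training", "num_epochs"): [100, 200],
        ("results", "recall"): [0.48, 0.61],
    },
    index=["episode_1", "episode_2"],
)
data.to_pickle("episodes.pickle")  # this is what the _save() methods do
# Run reads its parameters back the same way:
print(pd.read_pickle("episodes.pickle").loc["episode_1"])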
class Run
A class that manages operations with a single episode record
Run(episode_name: str, project_path: str, meta_path: str = None, params: Dict = None)

Parameters
----------
episode_name : str
    the name of the episode
project_path : str
    the path to the project folder
meta_path : str, optional
    the path to the pickled SavedRuns dataframe
params : dict, optional
    alternative to meta_path: a pre-loaded pandas Series of episode parameters
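A minimal sketch of loading an episode record; the meta file location is an assumption based on the default project layout, and the episode name is hypothetical:

from dlc2action.project.meta import Run

run = Run(
    "first_episode",
    project_path="/home/user/project",
    meta_path="/home/user/project/meta/episodes.pickle",  # assumed meta file location
)
print(run.training_time())  # training duration in seconds (NaN if not recorded)
print(run.dataset_name())   # name of the dataset record the episode used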
training_time(self) -> int

Get the training time in seconds

Returns
-------
training_time : int
    the training time in seconds
model_file(self, load_epoch: int = None) -> str

Get a checkpoint file path

Parameters
----------
load_epoch : int, optional
    the epoch to load (the closest checkpoint will be chosen; if not given, the last is used)

Returns
-------
checkpoint_path : str
    the path to the checkpoint
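The checkpoint is picked by parsing epoch indices out of the file names (the `file[5:]` slice in the source implies names like `epoch15.pt`). A small sketch of the selection logic under that naming assumption:

import numpy as np

# hypothetical contents of the model_save_path folder
model_files = ["epoch5.pt", "epoch10.pt", "epoch20.pt"]

load_epoch = 12
epochs = [int(f[5:].split(".")[0]) for f in model_files]  # -> [5, 10, 20]
closest = model_files[np.argmin([abs(e - load_epoch) for e in epochs])]
print(closest)  # "epoch10.pt" -- the checkpoint closest to epoch 12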
dataset_name(self) -> str

Get the dataset name

Returns
-------
dataset_name : str
    the name of the dataset record
split_file(self) -> str

Get the split file

Returns
-------
split_path : str
    the path to the split file
log_file(self) -> str

Get the log file

Returns
-------
log_path : str
    the path to the log file
split_info(self) -> Dict

Get the train/test/val split information

Returns
-------
split_info : dict
    a dictionary with [val_frac, test_frac, partition_method] keys and corresponding values
same_split_info(self, split_info: Dict) -> bool

Check whether this episode has the same split information

Parameters
----------
split_info : dict
    a dictionary with [val_frac, test_frac, partition_method] keys and corresponding values from another episode

Returns
-------
result : bool
    if True, this episode has the same split information
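For instance, two episodes can be checked for comparable evaluation settings by comparing their split dictionaries (paths and episode names hypothetical):

from dlc2action.project.meta import Run

# hypothetical paths, following the default project layout
project_path = "/home/user/project"
meta_path = "/home/user/project/meta/episodes.pickle"

run_a = Run("episode_a", project_path, meta_path=meta_path)
run_b = Run("episode_b", project_path, meta_path=meta_path)

# True only if val_frac, test_frac and partition_method all match
if run_a.same_split_info(run_b.split_info()):
    print("the two episodes used the same split settings")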
get_metrics(self) -> List

Get a list of metric names in the episode log

Returns
-------
metrics : List
    a list of string metric names
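The log parsing in `get_metrics`, `get_metric_log` and `get_epoch_list` implies a simple text format: one `[epoch N]: metric value, ...` line per training epoch, optionally followed by a `validation: ...` line. A sketch with a fabricated two-epoch log:

# two epochs of a hypothetical log file, in the format the parsers expect
log_lines = [
    "[epoch 1]: accuracy 0.61, recall 0.45",
    "validation: accuracy 0.58, recall 0.41",
    "[epoch 2]: accuracy 0.70, recall 0.52",
    "validation: accuracy 0.66, recall 0.49",
]
# get_metrics() would return ["accuracy", "recall"];
# get_metric_log("val", "recall") would return np.array([0.41, 0.49])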
get_metric_log(self, mode: str, metric_name: str) -> np.ndarray

Get the metric log

Parameters
----------
mode : {'train', 'val'}
    the mode to get the log from
metric_name : str
    the metric to get the log for (has to be one of the metrics computed for this episode during training)

Returns
-------
log : np.ndarray
    the log of metric values (empty if the metric was not computed during training)
get_epoch_list(self, mode) -> List

Get a list of epoch indices

Parameters
----------
mode : {'train', 'val'}
    the mode to get the epoch list for

Returns
-------
epoch_list : list
    a list of int epoch indices
unfinished(self) -> bool

Check whether this episode was interrupted

Returns
-------
result : bool
    True if the number of epochs in the log file is smaller than in the parameters
get_class_ind(self, class_name: str) -> int

Get the integer label from a class name

Parameters
----------
class_name : str
    the name of the class

Returns
-------
class_ind : int
    the integer label
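`get_class_ind` is the inverse lookup of the stored behaviors dictionary; with a fabricated mapping:

# a hypothetical behaviors_dict as stored in the episode's meta parameters
behaviors_dict = {0: "running", 1: "grooming", 2: "resting"}

# run.get_behaviors_dict() would return this mapping, and
# run.get_class_ind("grooming") would scan it and return 1;
# an unknown class name raises a ValueError instead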
get_behaviors_dict(self) -> Dict

Get the behaviors dictionary of the episode

Returns
-------
behaviors_dict : dict
    a dictionary with class indices as keys and labels as values
get_num_classes(self) -> int

Get the number of classes in the episode

Returns
-------
num_classes : int
    the number of classes
class DecisionThresholds
A class that saves and looks up tuned decision thresholds
DecisionThresholds(path: str)

Parameters
----------
path : str
    the path to the pickled thresholds dataframe
save_thresholds(self, episode_names: List, epochs: List, metric_name: str, metric_parameters: Dict, thresholds: List) -> None

Add a new record

Parameters
----------
episode_names : list
    the names of the episodes
epochs : list
    the epoch indices
metric_name : str
    the name of the metric the thresholds were tuned on
metric_parameters : dict
    the metric parameter dictionary
thresholds : list
    a list of float decision thresholds
find_thresholds(self, episode_names: List, epochs: List, metric_name: str, metric_parameters: Dict) -> Union[List, None]

Find a record

Parameters
----------
episode_names : list
    the names of the episodes
epochs : list
    the epoch indices
metric_name : str
    the name of the metric the thresholds were tuned on
metric_parameters : dict
    the metric parameter dictionary

Returns
-------
thresholds : list
    a list of float decision thresholds (None if no matching record exists)
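A record is keyed by the set of (episode, epoch) pairs plus the remaining metric parameters ("average", "threshold_value" and "ignored_classes" are stripped before matching), so a lookup only succeeds with the same combination. A minimal round-trip sketch, with a hypothetical file path and episode names:

from dlc2action.project.meta import DecisionThresholds

store = DecisionThresholds("/home/user/project/meta/thresholds.pickle")  # assumed path

store.save_thresholds(
    episode_names=["episode_1", "episode_2"],
    epochs=[50, 50],
    metric_name="f1",
    metric_parameters={"beta": 1},
    thresholds=[0.4, 0.55, 0.3],
)
# the same keys retrieve the saved list; any other combination returns None
thresholds = store.find_thresholds(
    ["episode_1", "episode_2"], [50, 50], "f1", {"beta": 1}
)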
class SavedRuns:
    """
    A class that manages operations with all episode (or prediction) records
    """

    def __init__(self, path: str, project_path: str) -> None:
        """
        Parameters
        ----------
        path : str
            the path to the pickled SavedRuns dataframe
        project_path : str
            the current project folder path
        """

        self.path = path
        self.project_path = project_path
        self.data = pd.read_pickle(path)

    def update(
        self,
        data: pd.DataFrame,
        data_path: str,
        annotation_path: str,
        name_map: Dict = None,
        force: bool = False,
    ) -> None:
        """
        Update with new data

        Parameters
        ----------
        data : pd.DataFrame
            the new dataframe
        data_path : str
            the new data path
        annotation_path : str
            the new annotation path
        name_map : dict, optional
            the name change dictionary; keys are old episode names and values are new episode names
        force : bool, default False
            replace existing episodes if `True`
        """

        if name_map is None:
            name_map = {}
        data = data.rename(index=name_map)
        for episode in data.index:
            new_model = os.path.join(self.project_path, "results", "model", episode)
            data.loc[episode, ("training", "model_save_path")] = new_model
            new_log = os.path.join(
                self.project_path, "results", "logs", f"{episode}.txt"
            )
            data.loc[episode, ("training", "log_file")] = new_log
            old_split = data.loc[episode, ("training", "split_path")]
            if old_split is None:
                new_split = None
            else:
                new_split = os.path.join(
                    self.project_path, "results", "splits", os.path.basename(old_split)
                )
            data.loc[episode, ("training", "split_path")] = new_split
            data.loc[episode, ("data", "data_path")] = data_path
            data.loc[episode, ("data", "annotation_path")] = annotation_path
            if episode in self.data.index:
                if force:
                    self.data = self.data.drop(index=[episode])
                else:
                    raise RuntimeError(f"The {episode} episode name is already taken!")
        self.data = pd.concat([self.data, data])
        self._save()

    def get_subset(self, episode_names: List) -> pd.DataFrame:
        """
        Get a subset of the raw metadata

        Parameters
        ----------
        episode_names : list
            a list of the episodes to include

        Returns
        -------
        subset : pd.DataFrame
            the subset of the metadata dataframe
        """

        for episode in episode_names:
            if episode not in self.data.index:
                raise ValueError(
                    f"The {episode} episode is not in the records; please run "
                    f"`Project.list_episodes()` to explore the records"
                )
        return self.data.loc[episode_names]

    def get_saved_data_path(self, episode_name: str) -> str:
        """
        Get the `saved_data_path` parameter for the episode

        Parameters
        ----------
        episode_name : str
            the name of the episode

        Returns
        -------
        saved_data_path : str
            the saved data path
        """

        return self.data.loc[episode_name]["data"]["saved_data_path"]

    def check_name_validity(self, episode_name: str) -> bool:
        """
        Check if an episode name already exists

        Parameters
        ----------
        episode_name : str
            the name to check

        Returns
        -------
        result : bool
            True if the name can be used
        """

        return episode_name not in self.data.index

    def update_episode_metrics(self, episode_name: str, metrics: Dict) -> None:
        """
        Update the metadata with evaluation results

        Parameters
        ----------
        episode_name : str
            the name of the episode to update
        metrics : dict
            a dictionary of the metrics
        """

        for key, value in metrics.items():
            self.data.loc[episode_name, ("results", key)] = value
        self._save()

    def save_episode(
        self,
        episode_name: str,
        parameters: Dict,
        behaviors_dict: Dict,
        suppress_validation: bool = False,
        training_time: str = None,
    ) -> None:
        """
        Save a new run record

        Parameters
        ----------
        episode_name : str
            the name of the episode
        parameters : dict
            the parameters to save
        behaviors_dict : dict
            the dictionary of behaviors (keys are indices, values are names)
        suppress_validation : bool, default False
            if `True`, an existing episode with the same name will be overwritten
        training_time : str, optional
            the training time in '%H:%M:%S' format
        """

        if not suppress_validation and episode_name in self.data.index:
            raise ValueError(f"Episode {episode_name} already exists!")
        pars = deepcopy(parameters)
        if "meta" not in pars:
            pars["meta"] = {
                "time": strftime("%Y-%m-%d %H:%M:%S", localtime()),
                "behaviors_dict": behaviors_dict,
            }
        else:
            pars["meta"]["time"] = strftime("%Y-%m-%d %H:%M:%S", localtime())
            pars["meta"]["behaviors_dict"] = behaviors_dict
        if training_time is not None:
            pars["meta"]["training_time"] = training_time
        if len(parameters.keys()) > 1:
            pars["losses"] = pars["losses"].get(pars["general"]["loss_function"], {})
            for metric_name in pars["general"]["metric_functions"]:
                pars[metric_name] = pars["metrics"].get(metric_name, {})
            if pars["general"].get("ssl", None) is not None:
                for ssl_name in pars["general"]["ssl"]:
                    pars[ssl_name] = pars["ssl"].get(ssl_name, {})
            for group_name in ["metrics", "ssl"]:
                if group_name in pars:
                    pars.pop(group_name)
        data = {
            (big_key, small_key): value
            for big_key, big_value in pars.items()
            for small_key, value in big_value.items()
        }
        list_keys = []
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", message="DataFrame is highly fragmented")
            for k, v in data.items():
                if k not in self.data.columns:
                    self.data[k] = np.nan
                if isinstance(v, list):
                    list_keys.append(k)
            for k in list_keys:
                self.data[k] = self.data[k].astype(object)
            self.data.loc[episode_name] = data
        self._save()

    def load_parameters(self, episode_name: str) -> Dict:
        """
        Load the task parameters from a record

        Parameters
        ----------
        episode_name : str
            the name of the episode to load

        Returns
        -------
        parameters : dict
            the loaded task parameters
        """

        parameters = defaultdict(lambda: defaultdict(lambda: {}))
        episode = self.data.loc[episode_name].dropna().to_dict()
        keys = ["data", "augmentations", "general", "training", "model", "features"]
        for key in episode:
            big_key, small_key = key
            if big_key in keys:
                parameters[big_key][small_key] = episode[key]
        ssl_keys = parameters["general"].get("ssl", None)
        metric_keys = parameters["general"].get("metric_functions", None)
        loss_key = parameters["general"]["loss_function"]
        if ssl_keys is None:
            ssl_keys = []
        if metric_keys is None:
            metric_keys = []
        for key in episode:
            big_key, small_key = key
            if big_key in ssl_keys:
                parameters["ssl"][big_key][small_key] = episode[key]
            elif big_key in metric_keys:
                parameters["metrics"][big_key][small_key] = episode[key]
            elif big_key == "losses":
                parameters["losses"][loss_key][small_key] = episode[key]
        parameters = {k: dict(v) for k, v in parameters.items()}
        parameters["general"]["num_classes"] = Run(
            episode_name, self.project_path, params=self.data.loc[episode_name]
        ).get_num_classes()
        return parameters

    def get_active_datasets(self) -> List:
        """
        Get a list of names of datasets that are used by unfinished episodes

        Returns
        -------
        active_datasets : list
            a list of dataset names used by unfinished episodes
        """

        active_datasets = []
        for episode_name in self.unfinished_episodes():
            run = Run(
                episode_name, self.project_path, params=self.data.loc[episode_name]
            )
            active_datasets.append(run.dataset_name())
        return active_datasets

    def list_episodes(
        self,
        episode_names: List = None,
        value_filter: str = "",
        display_parameters: List = None,
    ) -> pd.DataFrame:
        """
        Get a filtered pandas dataframe with episode metadata

        Parameters
        ----------
        episode_names : List
            a list of strings of episode names
        value_filter : str
            a string of filters to apply of this general structure:
            'group_name1/par_name1::(<>=)value1,group_name2/par_name2::(<>=)value2', e.g.
            'data/overlap::=50,results/recall::>0.5,data/feature_extraction::=kinematic'
        display_parameters : List
            list of parameters to display (e.g. ['data/overlap', 'results/recall'])

        Returns
        -------
        pd.DataFrame
            the filtered dataframe
        """

        if episode_names is not None:
            data = deepcopy(self.data.loc[episode_names])
        else:
            data = deepcopy(self.data)
        if len(data) == 0:
            return pd.DataFrame()
        try:
            filters = value_filter.split(",")
            if filters == [""]:
                filters = []
            for f in filters:
                par_name, condition = f.split("::")
                group_name, par_name = par_name.split("/")
                sign, value = condition[0], condition[1:]
                if value[0] == "=":
                    sign += "="
                    value = value[1:]
                try:
                    value = float(value)
                except ValueError:
                    if value == "True":
                        value = True
                    elif value == "False":
                        value = False
                    elif value == "None":
                        value = None
                if value is None:
                    if sign == "=":
                        data = data[data[group_name][par_name].isna()]
                    elif sign == "!=":
                        data = data[~data[group_name][par_name].isna()]
                elif sign == ">":
                    data = data[data[group_name][par_name] > value]
                elif sign == ">=":
                    data = data[data[group_name][par_name] >= value]
                elif sign == "<":
                    data = data[data[group_name][par_name] < value]
                elif sign == "<=":
                    data = data[data[group_name][par_name] <= value]
                elif sign == "=":
                    data = data[data[group_name][par_name] == value]
                elif sign == "!=":
                    data = data[data[group_name][par_name] != value]
                else:
                    raise ValueError(
                        "Please use one of the signs: [>, <, >=, <=, =, !=]"
                    )
        except ValueError:
            raise ValueError(
                f"The {value_filter} filter is not valid, please use the following format:"
                f" 'group1/parameter1::[sign][value],group2/parameter2::[sign][value]', "
                f"e.g. 'training/num_epochs::>=200,model/num_f_maps::=128,meta/time::>2022-06-01'"
            )
        if display_parameters is not None:
            if isinstance(display_parameters[0], str):
                display_parameters = [
                    (x.split("/")[0], x.split("/")[1]) for x in display_parameters
                ]
            display_parameters = [x for x in display_parameters if x in data.columns]
            data = data[display_parameters]
        return data

    def rename_episode(self, episode_name: str, new_episode_name: str) -> None:
        """
        Rename an episode record and update its model and log paths

        Parameters
        ----------
        episode_name : str
            the current name of the episode
        new_episode_name : str
            the new name of the episode
        """

        if episode_name not in self.data.index:
            raise ValueError(f"The {episode_name} episode is not in the records!")
        if new_episode_name in self.data.index:
            raise ValueError(f"The {new_episode_name} episode name is already taken!")
        self.data.loc[new_episode_name] = self.data.loc[episode_name]
        model_path = self.data.loc[new_episode_name, ("training", "model_save_path")]
        self.data.loc[new_episode_name, ("training", "model_save_path")] = os.path.join(
            os.path.dirname(model_path), new_episode_name
        )
        log_path = self.data.loc[new_episode_name, ("training", "log_file")]
        self.data.loc[new_episode_name, ("training", "log_file")] = os.path.join(
            os.path.dirname(log_path), f"{new_episode_name}.txt"
        )
        self.data = self.data.drop(index=episode_name)
        self._save()

    def remove_episode(self, episode_name: str) -> None:
        """
        Remove the metafile record of an episode

        Parameters
        ----------
        episode_name : str
            the name of the episode to remove
        """

        if episode_name in self.data.index:
            self.data = self.data.drop(index=episode_name)
            self._save()

    def unfinished_episodes(self) -> List:
        """
        Get a list of unfinished episodes (currently running or interrupted)

        Returns
        -------
        interrupted_episodes : List
            a list of string names of unfinished episodes in the records
        """

        unfinished = []
        for name, params in self.data.iterrows():
            if Run(name, project_path=self.project_path, params=params).unfinished():
                unfinished.append(name)
        return unfinished

    def update_episode_results(
        self,
        episode_name: str,
        logs: Tuple,
        training_time: str = None,
    ) -> None:
        """
        Add results to an episode record

        Parameters
        ----------
        episode_name : str
            the name of the episode to update
        logs : tuple
            the tuple of log dictionaries returned by task.train() (the metric log
            is the second element)
        training_time : str, optional
            the training time
        """

        metrics_log = logs[1]
        results = {}
        for key, value in metrics_log["val"].items():
            results[("results", key)] = value[-1]
        if training_time is not None:
            results[("meta", "training_time")] = training_time
        for k, v in results.items():
            self.data.loc[episode_name, k] = v
        self._save()

    def get_runs(self, episode_name: str) -> List:
        """
        Get a list of runs with this episode name (episodes like `episode_name::0`)

        Parameters
        ----------
        episode_name : str
            the name of the episode

        Returns
        -------
        runs_list : List
            a list of string run names
        """

        if episode_name is None:
            return []
        runs_list = []
        for name in self.data.index:
            if name.startswith(episode_name):
                split = name.split("::")
                if split[0] == episode_name:
                    if (len(split) > 1 and split[-1].isnumeric()) or len(split) == 1:
                        runs_list.append(name)
                elif name == episode_name:
                    runs_list.append(name)
        return runs_list

    def _save(self):
        """
        Save the dataframe
        """

        self.data.copy().to_pickle(self.path)
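For orientation, here is a minimal usage sketch of the filter grammar that list_episodes parses; the pickle path, project path, and parameter names below are hypothetical stand-ins for a real project's records.

# A hypothetical query against an existing records file (paths are made up).
runs = SavedRuns("/path/to/project/meta/episodes.pickle", "/path/to/project")

# Each filter item is 'group/parameter::[sign][value]'; items are comma-separated.
# Values are parsed as float first, then as True/False/None, and compared with
# >, >=, <, <=, = or != (None only supports = and !=).
df = runs.list_episodes(
    value_filter="training/num_epochs::>=200,results/recall::>0.5",
    display_parameters=["training/num_epochs", "results/recall"],
)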
class Searches(SavedRuns):
    """
    A class that manages operations with search records
    """

    def save_search(
        self,
        search_name: str,
        parameters: Dict,
        n_trials: int,
        best_params: Dict,
        best_value: float,
        metric: str,
        search_space: Dict,
    ) -> None:
        """
        Save a new search record

        Parameters
        ----------
        search_name : str
            the name of the search to save
        parameters : dict
            the task parameters to save
        n_trials : int
            the number of trials in the search
        best_params : dict
            the best parameters dictionary
        best_value : float
            the best value
        metric : str
            the name of the objective metric
        search_space : dict
            a dictionary representing the search space; of this general structure:
            {'group/param_name': ('float/int/float_log/int_log', start, end),
            'group/param_name': ('categorical', [choices])}, e.g.
            {'data/overlap': ('int', 5, 100), 'training/lr': ('float_log', 1e-4, 1e-2),
            'data/feature_extraction': ('categorical', ['kinematic', 'bones'])}
        """

        pars = deepcopy(parameters)
        pars["results"] = {"best_value": best_value, "best_params": best_params}
        pars["meta"] = {
            "objective": metric,
            "n_trials": n_trials,
            "search_space": search_space,
        }
        self.save_episode(search_name, pars, {})

    def get_best_params_raw(self, search_name: str) -> Dict:
        """
        Get the raw dictionary of best parameters found by a search

        Parameters
        ----------
        search_name : str
            the name of the search

        Returns
        -------
        best_params : dict
            a dictionary of the best parameters where the keys are in '{group}/{name}' format
        """

        return self.data.loc[search_name]["results"]["best_params"]

    def get_best_params(
        self,
        search_name: str,
        load_parameters: List = None,
        round_to_binary: List = None,
    ) -> Tuple:
        """
        Get the best parameters from a search

        Parameters
        ----------
        search_name : str
            the name of the search
        load_parameters : List, optional
            a list of string names of the parameters to load (if not provided all parameters are loaded)
        round_to_binary : List, optional
            a list of string names of the loaded parameters that should be rounded to the nearest power of two

        Returns
        -------
        best_params : dict
            a dictionary of the best parameters
        model_name : str
            the name of the model used in the search
        """

        if round_to_binary is None:
            round_to_binary = []
        params = self.data.loc[search_name]["results"]["best_params"]
        if load_parameters is not None:
            params = {k: v for k, v in params.items() if k in load_parameters}
        for par_name in round_to_binary:
            if par_name not in params:
                continue
            if not isinstance(params[par_name], (int, float)):
                raise TypeError(
                    f"Cannot round the {par_name} parameter of type "
                    f"{type(params[par_name])} to a power of two"
                )
            i = 1
            while 2**i < params[par_name]:
                i += 1
            if params[par_name] - (2 ** (i - 1)) < (2**i) - params[par_name]:
                params[par_name] = 2 ** (i - 1)
            else:
                params[par_name] = 2**i
        res = defaultdict(lambda: defaultdict(lambda: {}))
        for k, v in params.items():
            big_key, small_key = k.split("/")[0], "/".join(k.split("/")[1:])
            if len(small_key.split("/")) == 1:
                res[big_key][small_key] = v
            else:
                group, key = small_key.split("/")
                res[big_key][group][key] = v
        model = self.data.loc[search_name]["general"]["model_name"]
        return res, model
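The power-of-two rounding branch in get_best_params is easy to misread, so here is the same arithmetic extracted into a standalone sketch (the helper name is ours, not part of the module):

def round_to_power_of_two(x):
    # Find the smallest power of two that is >= x (starting from 2**1,
    # exactly as in get_best_params), then keep the closer neighbour.
    i = 1
    while 2**i < x:
        i += 1
    return 2 ** (i - 1) if x - 2 ** (i - 1) < 2**i - x else 2**i

assert round_to_power_of_two(100) == 128  # 128 is closer to 100 than 64
assert round_to_power_of_two(65) == 64    # 64 is closer to 65 than 128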
class Suggestions(SavedRuns):
    """
    A class that manages operations with suggestion records
    """

    def save_suggestion(
        self, episode_name: str, parameters: Dict, meta_parameters: Dict
    ) -> None:
        """
        Save a new suggestion record

        Parameters
        ----------
        episode_name : str
            the name of the suggestion episode
        parameters : dict
            the task parameters to save
        meta_parameters : dict
            a dictionary of meta parameters to store under the 'meta' group
        """

        pars = deepcopy(parameters)
        pars["meta"] = meta_parameters
        super().save_episode(episode_name, pars, behaviors_dict=None)
class SavedStores:
    """
    A class that manages operations with saved dataset records
    """

    def __init__(self, path: str):
        """
        Parameters
        ----------
        path : str
            the path to the pickled SavedStores dataframe
        """

        self.path = path
        self.data = pd.read_pickle(path)
        self.skip_keys = [
            "feature_save_path",
            "saved_data_path",
            "real_lens",
            "recompute_annotation",
        ]

    def clear(self) -> None:
        """
        Remove all datasets
        """

        for dataset_name in self.data.index:
            self.remove_dataset(dataset_name)

    def dataset_names(self) -> List:
        """
        Get a list of dataset names

        Returns
        -------
        dataset_names : List
            a list of string dataset names
        """

        return list(self.data.index)

    def remove(self, names: List) -> None:
        """
        Remove some datasets

        Parameters
        ----------
        names : List
            a list of string names of the datasets to delete
        """

        for dataset_name in names:
            if dataset_name in self.data.index:
                self.remove_dataset(dataset_name)

    def remove_dataset(self, dataset_name: str) -> None:
        """
        Remove a dataset record

        Parameters
        ----------
        dataset_name : str
            the name of the dataset to remove
        """

        if dataset_name in self.data.index:
            self.data = self.data.drop(index=dataset_name)
            self._save()

    def find_name(self, parameters: Dict) -> str:
        """
        Find a record that satisfies the parameters (if it exists)

        Parameters
        ----------
        parameters : dict
            a dictionary of data parameters

        Returns
        -------
        name : str
            the name of a record that has the same parameters (None if it does not
            exist; the earliest if there are several)
        """

        filter_dict = deepcopy(parameters)
        for key, value in parameters.items():
            if value is None or key in self.skip_keys:
                filter_dict.pop(key)
            elif key not in self.data.columns:
                return None
        saved_annotation = self.data[
            (self.data[list(filter_dict)] == pd.Series(filter_dict)).all(axis=1)
        ]
        for i in range(len(saved_annotation)):
            ok = True
            for key in saved_annotation.columns:
                if key in self.skip_keys:
                    continue
                isnull = pd.isnull(saved_annotation.iloc[i][key])
                if not isinstance(isnull, bool):
                    isnull = False
                if key not in filter_dict and not isnull:
                    ok = False
            if ok:
                return saved_annotation.iloc[i].name
        return None

    def save_store(self, episode_name: str, parameters: Dict) -> None:
        """
        Save a new saved dataset record

        Parameters
        ----------
        episode_name : str
            the name of the dataset
        parameters : dict
            a dictionary of data parameters
        """

        pars = deepcopy(parameters)
        for k in parameters:
            if k not in self.data.columns:
                self.data[k] = np.nan
        if self.find_name(pars) is None:
            self.data.loc[episode_name] = pars
        self._save()

    def _save(self):
        """
        Save the dataframe
        """

        self.data.to_pickle(self.path)

    def check_name_validity(self, store_name: str) -> bool:
        """
        Check if a store name already exists

        Parameters
        ----------
        store_name : str
            the name to check

        Returns
        -------
        result : bool
            True if the name can be used
        """

        return store_name not in self.data.index
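As a rough illustration of how find_name and save_store fit together, the sketch below checks for an existing dataset record before registering a new one; the path and the parameter names and values are invented for the example.

stores = SavedStores("/path/to/project/meta/saved_datasets.pickle")
params = {"data_type": "dlc_track", "overlap": 50, "feature_save_path": None}

# find_name drops None values and the skip_keys entries from the filter, and
# only matches records whose remaining parameters agree exactly.
name = stores.find_name(params)
if name is None:
    stores.save_store("my_dataset", params)  # register a new dataset record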