dlc2action.project.meta
Handling meta (history) files.
#
# Copyright 2020-present by A. Mathis Group and contributors. All rights reserved.
#
# This project and all its files are licensed under GNU AGPLv3 or later version.
# A copy is included in dlc2action/LICENSE.AGPL.
#
"""Handling meta (history) files."""

import ast
import operator
import os
import re
import warnings
from abc import abstractmethod
from collections import defaultdict
from copy import deepcopy
from time import localtime, strftime
from typing import Dict, List, Optional, Set, Tuple, Union

import numpy as np
import pandas as pd

from dlc2action.utils import correct_path


class Run:
    """A class that manages operations with a single episode record."""

    def __init__(
        self,
        episode_name: str,
        project_path: str,
        meta_path: str = None,
        params: Dict = None,
    ):
        """Initialize the class.

        Parameters
        ----------
        episode_name : str
            the name of the episode
        project_path : str
            the path to the project folder
        meta_path : str, optional
            the path to the pickled SavedRuns dataframe
        params : dict, optional
            alternative to meta_path: pre-loaded pandas Series of episode parameters

        Raises
        ------
        ValueError
            if the episode cannot be loaded from `meta_path` or neither source is given

        """
        self.name = episode_name
        self.project_path = project_path
        if meta_path is not None:
            try:
                self.params = pd.read_pickle(meta_path).loc[episode_name]
            except Exception as err:
                # Any failure (missing file, missing row) is reported as a missing episode.
                raise ValueError(f"The {episode_name} episode does not exist!") from err
        elif params is not None:
            self.params = params
        else:
            raise ValueError("Either meta_path or params has to be not None")
        self.params = self._check_str_conversion()

    def _check_str_conversion(self):
        """Check if the parameters are in string format and convert them to the correct type."""
        return _check_str_conversion(self.params)

    def training_time(self) -> Union[int, float]:
        """Get the training time in seconds.

        Returns
        -------
        training_time : int
            the training time in seconds (`np.nan` if it was not recorded)

        """
        time_str = self.params["meta"].get("training_time")
        try:
            if time_str is None or np.isnan(time_str):
                return np.nan
        except TypeError:
            # `np.isnan` raises on strings, which is the expected '%H:%M:%S' case.
            pass
        hours, minutes, seconds = time_str.split(":")
        return int(hours) * 3600 + int(minutes) * 60 + int(seconds)

    @staticmethod
    def _checkpoint_epoch(file_name: str) -> Optional[int]:
        """Parse the epoch index from a checkpoint file name.

        The first five characters are dropped and the part before the first dot is
        parsed as an integer (names presumably look like ``epoch<N>.<ext>``).
        Returns None when the name does not follow that pattern.
        """
        try:
            return int(file_name[5:].split(".")[0])
        except ValueError:
            return None

    def model_file(self, load_epoch: int = None) -> Optional[str]:
        """Get a checkpoint file path.

        Parameters
        ----------
        load_epoch : int, optional
            the epoch to load (the closest checkpoint will be chosen; if not given will be set to last)

        Returns
        -------
        checkpoint_path : str or None
            the path to the checkpoint (None if no suitable checkpoint file exists)

        """
        model_path = correct_path(
            self.params["training"]["model_save_path"], self.project_path
        )
        model_files = os.listdir(model_path)
        numbered = [(self._checkpoint_epoch(name), name) for name in model_files]
        if load_epoch is None:
            if not numbered:
                return None
            # Compare epochs numerically: a plain string sort ranks "epoch9" after
            # "epoch10". Names without a parseable epoch rank below numbered ones.
            _, model_file = max(
                numbered,
                key=lambda item: (
                    item[0] is not None,
                    item[0] if item[0] is not None else 0,
                    item[1],
                ),
            )
        else:
            candidates = [
                (abs(epoch - load_epoch), epoch, name)
                for epoch, name in numbered
                if epoch is not None
            ]
            if not candidates:
                return None
            # Closest epoch wins; ties go to the earlier epoch (deterministic,
            # unlike the arbitrary order returned by os.listdir).
            _, _, model_file = min(candidates)
        return os.path.join(model_path, model_file)

    def dataset_name(self) -> str:
        """Get the dataset name.

        Returns
        -------
        dataset_name : str
            the name of the dataset record (the base name of the feature save path)

        """
        data_path = correct_path(
            self.params["data"]["feature_save_path"], self.project_path
        )
        return os.path.basename(data_path)

    def split_file(self) -> str:
        """Get the split file.

        Returns
        -------
        split_path : str
            the path to the split file

        """
        return correct_path(self.params["training"]["split_path"], self.project_path)

    def log_file(self) -> str:
        """Get the log file.

        Returns
        -------
        log_path : str
            the path to the log file

        """
        return correct_path(self.params["training"]["log_file"], self.project_path)

    def split_info(self) -> Dict:
        """Get the train/test/val split information.

        Returns
        -------
        split_info : dict
            a dictionary with [val_frac, test_frac, partition_method] keys and corresponding values

        """
        training = self.params["training"]
        return {
            "val_frac": training["val_frac"],
            "test_frac": training["test_frac"],
            "partition_method": training["partition_method"],
        }

    def same_split_info(self, split_info: Dict) -> bool:
        """Check whether this episode has the same split information.

        Parameters
        ----------
        split_info : dict
            a dictionary with [val_frac, test_frac, partition_method] keys and corresponding values from another episode

        Returns
        -------
        result : bool
            if True, this episode has the same split information

        """
        own = self.split_info()
        return all(
            own[key] == split_info[key]
            for key in ["val_frac", "test_frac", "partition_method"]
        )

    def get_metric_log(self, mode: str, metric_name: str) -> np.ndarray:
        """Get the metric log.

        Lines starting with ``[epoch`` are read in 'train' mode and lines starting with
        ``validation`` in 'val' mode. If `metric_name` is not logged directly, the mean of
        all logged metrics whose name before the first underscore equals `metric_name`
        (e.g. per-behavior values) is used instead.

        Parameters
        ----------
        mode : {'train', 'val'}
            the mode to get the log from
        metric_name : str
            the metric to get the log for (has to be one of the metric computed for this episode during training)

        Returns
        -------
        log : np.ndarray
            the log of metric values (empty if the metric was not computed during training)

        """
        metric_array = []
        with open(self.log_file()) as f:
            for line in f:
                if mode == "train" and line.startswith("[epoch"):
                    line = line.split("]: ")[1]
                elif mode == "val" and line.startswith("validation"):
                    line = line.split("validation: ")[1]
                else:
                    continue
                metrics = line.split(", ")
                names = [m.split()[0] for m in metrics]
                if metric_name in names:
                    _, value = metrics[names.index(metric_name)].split()
                    metric_array.append(float(value))
                else:
                    per_class = [
                        float(m.split()[1])
                        for m in metrics
                        if m.split()[0].split("_")[0] == metric_name
                    ]
                    if per_class:
                        metric_array.append(np.mean(per_class))
        return np.array(metric_array)

    def get_epoch_list(self, mode) -> List:
        """Get a list of epoch indices.

        In 'train' mode every ``[epoch N]`` line contributes N. In 'val' mode an epoch is
        listed only when a ``validation`` line follows its ``[epoch N]`` line, so the result
        lines up with `get_metric_log(mode='val', ...)`.

        Parameters
        ----------
        mode : {'train', 'val'}
            the mode to get the epoch list for

        Returns
        -------
        epoch_list : list
            a list of int epoch indices (empty for any other mode)

        """
        epoch_list = []
        last_epoch = None
        with open(self.log_file()) as f:
            for line in f:
                if line.startswith("[epoch"):
                    last_epoch = int(line[7:].split("]:")[0])
                    if mode == "train":
                        epoch_list.append(last_epoch)
                elif (
                    mode == "val"
                    and line.startswith("validation")
                    and last_epoch is not None
                ):
                    epoch_list.append(last_epoch)
        return epoch_list

    def get_metrics(self) -> List:
        """Get a list of metric names in the episode log.

        Only the first ``[epoch`` or ``validation`` line of the log is inspected.

        Returns
        -------
        metrics : List
            a list of string metric names

        """
        metrics = []
        with open(self.log_file()) as f:
            for line in f.readlines():
                if line.startswith("[epoch"):
                    line = line.split("]: ")[1]
                elif line.startswith("validation"):
                    line = line.split("validation: ")[1]
                else:
                    continue
                for metric in line.split(", "):
                    name, _ = metric.split()
                    metrics.append(name)
                break
        return metrics

    def unfinished(self) -> bool:
        """Check whether this episode was interrupted.

        Returns
        -------
        result : bool
            True if the number of epochs in the log file is smaller than in the parameters
            (also True if the log file does not exist; False if the log path is not a string)

        """
        num_epoch_theor = self.params["training"]["num_epochs"]
        log_file = self.log_file()
        if not isinstance(log_file, str):
            return False
        if not os.path.exists(log_file):
            return True
        with open(log_file) as f:
            lines = f.readlines()
        num_epoch = len(lines)
        # When the second line is a validation line, each epoch occupies two lines.
        if num_epoch >= 2 and lines[1].startswith("validation"):
            num_epoch //= 2
        return num_epoch < num_epoch_theor

    def get_class_ind(self, class_name: str) -> int:
        """Get the integer label from a class name.

        Parameters
        ----------
        class_name : str
            the name of the class

        Returns
        -------
        class_ind : int
            the integer label

        Raises
        ------
        ValueError
            if the class is not predicted by this episode

        """
        behaviors_dict = self.get_behaviors_dict()
        for index, name in behaviors_dict.items():
            if name == class_name:
                return index
        raise ValueError(
            f"The {class_name} class is not in classes predicted by {self.name} ({behaviors_dict})"
        )

    def get_behaviors_dict(self) -> Dict:
        """Get behaviors dictionary in the episode.

        Returns
        -------
        behaviors_dict : dict
            a dictionary with class indices as keys and labels as values (parsed with
            `ast.literal_eval` if it is stored as a string)

        """
        behavior_dict = self.params["meta"]["behaviors_dict"]
        if isinstance(behavior_dict, str):
            behavior_dict = ast.literal_eval(behavior_dict)
        return behavior_dict

    def get_num_classes(self) -> int:
        """Get number of classes in episode.

        Returns
        -------
        num_classes : int
            the number of classes

        """
        # Go through get_behaviors_dict so a string-stored dictionary is parsed
        # first instead of counting its characters.
        return len(self.get_behaviors_dict())


class DecisionThresholds:
    """A class that saves and looks up tuned decision thresholds."""

    # Metric parameters that are not part of a threshold record's identity.
    _SKIPPED_METRIC_KEYS = ("average", "threshold_value", "ignored_classes")

    def __init__(self, path: str) -> None:
        """Initialize the class.

        Parameters
        ----------
        path : str
            the path to the pickled SavedRuns dataframe

        """
        self.path = path
        self.data = pd.read_pickle(path)

    def _record_parameters(self, metric_name: str, metric_parameters: Dict) -> Dict:
        """Build the ``(metric_name, parameter)`` keyed record without mutating the input."""
        return {
            (metric_name, key): value
            for key, value in metric_parameters.items()
            if key not in self._SKIPPED_METRIC_KEYS
        }

    def save_thresholds(
        self,
        episode_names: List,
        epochs: List,
        metric_name: str,
        metric_parameters: Dict,
        thresholds: List,
    ) -> None:
        """Add a new record.

        Parameters
        ----------
        episode_names : list
            the names of the episodes
        epochs : list
            the epoch index list
        metric_name : str
            the name of the metric the thresholds were tuned on
        metric_parameters : dict
            the metric parameter dictionary (not modified)
        thresholds : list
            a list of float decision thresholds

        """
        parameters = self._record_parameters(metric_name, metric_parameters)
        parameters["thresholds"] = thresholds
        parameters["episodes"] = set(zip(episode_names, epochs))
        record = {key: [value] for key, value in parameters.items()}
        self.data = pd.concat([self.data, pd.DataFrame.from_dict(record)], axis=0)
        self._save()

    def find_thresholds(
        self,
        episode_names: List,
        epochs: List,
        metric_name: str,
        metric_parameters: Dict,
    ) -> Union[List, None]:
        """Find a record.

        Parameters
        ----------
        episode_names : list
            the names of the episodes
        epochs : list
            the epoch index list
        metric_name : str
            the name of the metric the thresholds were tuned on
        metric_parameters : dict
            the metric parameter dictionary (not modified)

        Returns
        -------
        thresholds : list
            a list of float decision thresholds (None if no matching record exists)

        """
        parameters = self._record_parameters(metric_name, metric_parameters)
        parameters["episodes"] = set(zip(episode_names, epochs))
        query = {}
        for key, value in parameters.items():
            if value is None:
                # None acts as a wildcard.
                continue
            if key not in self.data.columns:
                return None
            query[key] = value
        matches = self.data[(self.data[list(query)] == pd.Series(query)).all(axis=1)]
        if len(matches) > 0:
            return matches.iloc[0]["thresholds"]
        return None

    def _save(self) -> None:
        """Save the records."""
        self.data.copy().to_pickle(self.path)


class SavedRuns:
    """A class that manages operations with all episode (or prediction) records."""

    def __init__(self, path: str, project_path: str) -> None:
        """Initialize the class.

        Parameters
        ----------
        path : str
            the path to the pickled SavedRuns dataframe
        project_path : str
            the path to the project folder

        """
        self.path = path
        self.project_path = project_path
        self.data = pd.read_pickle(path)
        self.data = _check_str_conversion(self.data)

    def update(
        self,
        data: pd.DataFrame,
        data_path: str,
        annotation_path: str,
        name_map: Dict = None,
        force: bool = False,
    ) -> None:
        """Update with new data.

        Model, log and split paths of the new records are redirected into this project's
        ``results`` folder, and the data and annotation paths are replaced.

        Parameters
        ----------
        data : pd.DataFrame
            the new dataframe
        data_path : str
            the new data path
        annotation_path : str
            the new annotation path
        name_map : dict, optional
            the name change dictionary; keys are old episode names and values are new episode names
        force : bool, default False
            replace existing episodes if `True`

        Raises
        ------
        RuntimeError
            if an episode name is already taken and `force` is False

        """
        if name_map is None:
            name_map = {}
        data = data.rename(index=name_map)
        for episode in data.index:
            data.loc[episode, ("training", "model_save_path")] = os.path.join(
                self.project_path, "results", "model", episode
            )
            data.loc[episode, ("training", "log_file")] = os.path.join(
                self.project_path, "results", "logs", f"{episode}.txt"
            )
            old_split = data.loc[episode, ("training", "split_path")]
            # A missing split path may be stored as None or as NaN.
            if old_split is None or (
                isinstance(old_split, float) and np.isnan(old_split)
            ):
                new_split = None
            else:
                new_split = os.path.join(
                    self.project_path, "results", "splits", os.path.basename(old_split)
                )
            data.loc[episode, ("training", "split_path")] = new_split
            data.loc[episode, ("data", "data_path")] = data_path
            data.loc[episode, ("data", "annotation_path")] = annotation_path
            if episode in self.data.index:
                if force:
                    self.data = self.data.drop(index=[episode])
                else:
                    raise RuntimeError(f"The {episode} episode name is already taken!")
        self.data = pd.concat([self.data, data])
        self._save()

    def get_subset(self, episode_names: List) -> pd.DataFrame:
        """Get a subset of the raw metadata.

        Parameters
        ----------
        episode_names : list
            a list of the episodes to include

        Returns
        -------
        subset : pd.DataFrame
            the subset of the raw metadata

        Raises
        ------
        ValueError
            if one of the episodes is not in the records

        """
        for episode in episode_names:
            if episode not in self.data.index:
                raise ValueError(
                    f"The {episode} episode is not in the records; please run `Project.list_episodes()` to explore the records"
                )
        return self.data.loc[episode_names]

    def get_saved_data_path(self, episode_name: str) -> str:
        """Get the `saved_data_path` parameter for the episode.

        Parameters
        ----------
        episode_name : str
            the name of the episode

        Returns
        -------
        saved_data_path : str
            the saved data path

        """
        return self.data.loc[episode_name]["data"]["saved_data_path"]

    def check_name_validity(self, episode_name: str) -> bool:
        """Check if an episode name already exists.

        Parameters
        ----------
        episode_name : str
            the name to check

        Returns
        -------
        result : bool
            True if the name can be used

        """
        return episode_name not in self.data.index

    def update_episode_metrics(self, episode_name: str, metrics: Dict) -> None:
        """Update meta data with evaluation results.

        Parameters
        ----------
        episode_name : str
            the name of the episode to update
        metrics : dict
            a dictionary of the metrics (stored under the "results" group)

        """
        for key, value in metrics.items():
            self.data.loc[episode_name, ("results", key)] = value
        self._save()

    def save_episode(
        self,
        episode_name: str,
        parameters: Dict,
        behaviors_dict: Dict,
        suppress_validation: bool = False,
        training_time: str = None,
    ) -> None:
        """Save a new run record.

        Parameters
        ----------
        episode_name : str
            the name of the episode
        parameters : dict
            the parameters to save (not modified; a deep copy is stored)
        behaviors_dict : dict
            the dictionary of behaviors (keys are indices, values are names)
        suppress_validation : bool, default False
            if True, existing episode with the same name will be overwritten
        training_time : str, optional
            the training time in '%H:%M:%S' format

        Raises
        ------
        ValueError
            if the episode already exists and `suppress_validation` is False

        """
        if not suppress_validation and episode_name in self.data.index:
            raise ValueError(f"Episode {episode_name} already exists!")
        pars = deepcopy(parameters)
        timestamp = strftime("%Y-%m-%d %H:%M:%S", localtime())
        if "meta" not in pars:
            pars["meta"] = {"time": timestamp, "behaviors_dict": behaviors_dict}
        else:
            pars["meta"]["time"] = timestamp
            pars["meta"]["behaviors_dict"] = behaviors_dict
        if training_time is not None:
            pars["meta"]["training_time"] = training_time
        if len(parameters.keys()) > 1:
            # Keep only the parameters of the loss, metrics and SSL tasks actually in use,
            # each stored as its own top-level group.
            pars["losses"] = pars["losses"].get(pars["general"]["loss_function"], {})
            for metric_name in pars["general"]["metric_functions"]:
                pars[metric_name] = pars["metrics"].get(metric_name, {})
            if pars["general"].get("ssl", None) is not None:
                for ssl_name in pars["general"]["ssl"]:
                    pars[ssl_name] = pars["ssl"].get(ssl_name, {})
            for group_name in ["metrics", "ssl"]:
                if group_name in pars:
                    pars.pop(group_name)
        # Flatten to (group, parameter) keys matching the MultiIndex columns.
        data = {
            (big_key, small_key): value
            for big_key, big_value in pars.items()
            for small_key, value in big_value.items()
        }
        list_keys = []
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", message="DataFrame is highly fragmented")
            for key, value in data.items():
                if key not in self.data.columns:
                    self.data[key] = np.nan
                if isinstance(value, list):
                    list_keys.append(key)
            # Columns holding lists must be object-typed to accept list values.
            for key in list_keys:
                self.data[key] = self.data[key].astype(object)
            self.data.loc[episode_name] = data
        self._save()

    def load_parameters(self, episode_name: str) -> Dict:
        """Load the task parameters from a record.

        Parameters
        ----------
        episode_name : str
            the name of the episode to load

        Returns
        -------
        parameters : dict
            the loaded task parameters

        """
        parameters = defaultdict(lambda: defaultdict(lambda: {}))
        episode = self.data.loc[episode_name].dropna().to_dict()
        keys = ["data", "augmentations", "general", "training", "model", "features"]
        for key, value in episode.items():
            big_key, small_key = key
            if big_key in keys:
                parameters[big_key][small_key] = value
        ssl_keys = parameters["general"].get("ssl", None)
        metric_keys = parameters["general"].get("metric_functions", None)
        loss_key = parameters["general"]["loss_function"]
        if ssl_keys is None:
            ssl_keys = []
        if metric_keys is None:
            metric_keys = []
        # Regroup the flattened loss / metric / SSL parameters under their parent groups.
        for key, value in episode.items():
            big_key, small_key = key
            if big_key in ssl_keys:
                parameters["ssl"][big_key][small_key] = value
            elif big_key in metric_keys:
                parameters["metrics"][big_key][small_key] = value
            elif big_key == "losses":
                parameters["losses"][loss_key][small_key] = value
        parameters = {k: dict(v) for k, v in parameters.items()}
        parameters["general"]["num_classes"] = Run(
            episode_name, self.project_path, params=self.data.loc[episode_name]
        ).get_num_classes()
        return parameters

    def get_active_datasets(self) -> List:
        """Get a list of names of datasets that are used by unfinished episodes.

        Returns
        -------
        active_datasets : list
            a list of dataset names used by unfinished episodes

        """
        return [
            Run(
                episode_name, self.project_path, params=self.data.loc[episode_name]
            ).dataset_name()
            for episode_name in self.unfinished_episodes()
        ]

    def list_episodes(
        self,
        episode_names: List = None,
        value_filter: str = "",
        display_parameters: List = None,
    ) -> pd.DataFrame:
        """Get a filtered pandas dataframe with episode metadata.

        Parameters
        ----------
        episode_names : List
            a list of strings of episode names
        value_filter : str
            a string of filters to apply of this general structure:
            'group_name1/par_name1::(<>=)value1,group_name2/par_name2::(<>=)value2', e.g.
            'data/overlap::=50,results/recall::>0.5,data/feature_extraction::=kinematic'
            (values are parsed as float when possible; 'True', 'False' and 'None' are
            special-cased, and 'None' only supports '=' and '!=')
        display_parameters : List
            list of parameters to display (e.g. ['data/overlap', 'results/recall'])

        Returns
        -------
        pandas.DataFrame
            the filtered dataframe

        Raises
        ------
        ValueError
            if the filter string is malformed

        """
        comparisons = {
            ">": operator.gt,
            ">=": operator.ge,
            "<": operator.lt,
            "<=": operator.le,
            "=": operator.eq,
            "!=": operator.ne,
        }
        if episode_names is not None:
            data = deepcopy(self.data.loc[episode_names])
        else:
            data = deepcopy(self.data)
        if len(data) == 0:
            return pd.DataFrame()
        try:
            filters = value_filter.split(",")
            if filters == [""]:
                filters = []
            for f in filters:
                par_name, condition = f.split("::")
                group_name, par_name = par_name.split("/")
                if not condition:
                    raise ValueError("empty condition")
                sign, value = condition[0], condition[1:]
                if value.startswith("="):
                    sign += "="
                    value = value[1:]
                try:
                    value = float(value)
                except ValueError:
                    value = {"True": True, "False": False, "None": None}.get(
                        value, value
                    )
                if sign not in comparisons or (
                    value is None and sign not in ("=", "!=")
                ):
                    raise ValueError(
                        "Please use one of the signs: [>, <, >=, <=, =, !=]"
                    )
                column = data[group_name][par_name]
                if value is None:
                    mask = column.isna() if sign == "=" else ~column.isna()
                else:
                    mask = comparisons[sign](column, value)
                data = data[mask]
        except ValueError as err:
            raise ValueError(
                f"The {value_filter} filter is not valid, please use the following format:"
                f" 'group1/parameter1::[sign][value],group2/parameter2::[sign][value]', "
                f"e.g. 'training/num_epochs::>=200,model/num_f_maps::=128,meta/time::>2022-06-01'"
            ) from err
        if display_parameters is not None:
            columns = []
            for parameter in display_parameters:
                if isinstance(parameter, str):
                    parts = parameter.split("/")
                    parameter = (parts[0], parts[1])
                columns.append(parameter)
            data = data[[column for column in columns if column in data.columns]]
        return data

    def rename_episode(self, episode_name, new_episode_name):
        """Rename an episode.

        The record is copied under the new name with its model save path and log file
        path pointed at the new name, and the old record is dropped.

        Parameters
        ----------
        episode_name : str
            the name of the episode to rename
        new_episode_name : str
            the new name of the episode

        Raises
        ------
        ValueError
            if the old name does not exist or the new name is already taken

        """
        if episode_name in self.data.index and new_episode_name not in self.data.index:
            self.data.loc[new_episode_name] = self.data.loc[episode_name]
            model_key = ("training", "model_save_path")
            model_path = self.data.loc[new_episode_name, model_key]
            self.data.loc[new_episode_name, model_key] = os.path.join(
                os.path.dirname(model_path), new_episode_name
            )
            log_key = ("training", "log_file")
            log_path = self.data.loc[new_episode_name, log_key]
            self.data.loc[new_episode_name, log_key] = os.path.join(
                os.path.dirname(log_path), f"{new_episode_name}.txt"
            )
            self.data = self.data.drop(index=episode_name)
            self._save()
        else:
            raise ValueError("The names are wrong")

    def remove_episode(self, episode_name: str) -> None:
        """Remove all model, logs and metafile records related to an episode.

        Parameters
        ----------
        episode_name : str
            the name of the episode to remove (ignored if it is not in the records)

        """
        if episode_name in self.data.index:
            self.data = self.data.drop(index=episode_name)
            self._save()

    def unfinished_episodes(self) -> List:
        """Get a list of unfinished episodes (currently running or interrupted).

        Returns
        -------
        interrupted_episodes: List
            a list of string names of unfinished episodes in the records

        """
        return [
            name
            for name, params in self.data.iterrows()
            if Run(name, project_path=self.project_path, params=params).unfinished()
        ]

    def update_episode_results(
        self,
        episode_name: str,
        logs: Tuple,
        training_time: str = None,
    ) -> None:
        """Add results to an episode record.

        Parameters
        ----------
        episode_name : str
            the name of the episode to update
        logs : tuple
            the logs returned by task.train(); the second element is a dictionary whose
            "val" entry maps metric names to per-epoch value sequences
        training_time : str, optional
            the training time

        """
        metrics_log = logs[1]
        results = {}
        for key, value in metrics_log["val"].items():
            # Only the final epoch is recorded; metrics with an empty log are skipped.
            if len(value) == 0:
                continue
            results[("results", key)] = value[-1]
        if training_time is not None:
            results[("meta", "training_time")] = training_time
        for key, value in results.items():
            self.data.loc[episode_name, key] = value
        self._save()

    def get_runs(self, episode_name: str) -> List:
        """Get a list of runs with this episode name (episodes like `episode_name#0`).

        Names are split on "::" when present and on "#" otherwise. A name matches if its
        first part equals `episode_name` and it is either unsplit or ends in a numeric
        suffix, or if it equals `episode_name` exactly.

        Parameters
        ----------
        episode_name : str
            the name of the episode

        Returns
        -------
        runs_list : List
            a list of string run names

        """
        if episode_name is None:
            return []
        runs_list = []
        for name in self.data.index:
            if not name.startswith(episode_name):
                continue
            split = name.split("::") if "::" in name else name.split("#")
            if split[0] == episode_name:
                if len(split) == 1 or split[-1].isnumeric():
                    runs_list.append(name)
            elif name == episode_name:
                runs_list.append(name)
        return runs_list

    def _save(self):
        """Save the dataframe."""
        self.data.copy().to_pickle(self.path)


class Searches(SavedRuns):
    """A class that manages operations with search records."""

    def save_search(
        self,
        search_name: str,
        parameters: Dict,
        n_trials: int,
        best_params: Dict,
        best_value: float,
        metric: str,
        search_space: Dict,
    ) -> None:
        """Save a new search record.

        Parameters
        ----------
        search_name : str
            the name of the search to save
        parameters : dict
            the task parameters to save
        n_trials : int
            the number of trials in the search
        best_params : dict
            the best parameters dictionary
        best_value : float
            the best value
        metric : str
            the name of the objective metric
        search_space : dict
            a dictionary representing the search space; of this general structure:
            {'group/param_name': ('float/int/float_log/int_log', start, end),
            'group/param_name': ('categorical', [choices])}, e.g.
            {'data/overlap': ('int', 5, 100), 'training/lr': ('float_log', 1e-4, 1e-2),
            'data/feature_extraction': ('categorical', ['kinematic', 'bones'])}

        """
        pars = deepcopy(parameters)
        pars["results"] = {"best_value": best_value, "best_params": best_params}
        pars["meta"] = {
            "objective": metric,
            "n_trials": n_trials,
            "search_space": search_space,
        }
        self.save_episode(search_name, pars, {})

    def get_best_params_raw(self, search_name: str) -> Dict:
        """Get the raw dictionary of best parameters found by a search.

        Parameters
        ----------
        search_name : str
            the name of the search

        Returns
        -------
        best_params : dict
            a dictionary of the best parameters where the keys are in '{group}/{name}' format

        """
        return self.data.loc[search_name]["results"]["best_params"]

    def get_best_params(
        self,
        search_name: str,
        load_parameters: List = None,
        round_to_binary: List = None,
    ) -> Tuple[Dict, str]:
        """Get the best parameters from a search.

        Parameters
        ----------
        search_name : str
            the name of the search
        load_parameters : List, optional
            a list of string names of the parameters to load (if not provided all parameters are loaded)
        round_to_binary : List, optional
            a list of string names of the loaded parameters that should be rounded to the nearest power of two

        Returns
        -------
        best_params : dict
            a nested dictionary of the best parameters ('{group}/{name}' keys become
            ``best_params[group][name]``, '{group}/{sub}/{name}' become ``best_params[group][sub][name]``)
        model : str
            the model name used in the search

        Raises
        ------
        TypeError
            if a parameter to round is not numeric

        """
        if round_to_binary is None:
            round_to_binary = []
        # Work on a copy so rounding never alters the dictionary stored in the records.
        params = dict(self.data.loc[search_name]["results"]["best_params"])
        if load_parameters is not None:
            params = {k: v for k, v in params.items() if k in load_parameters}
        for par_name in round_to_binary:
            if par_name not in params:
                continue
            value = params[par_name]
            if not isinstance(value, (float, int)):
                raise TypeError(
                    f"Cannot round {par_name} parameter of type {type(value)} to a power of two"
                )
            exponent = 1
            while 2**exponent < value:
                exponent += 1
            lower, upper = 2 ** (exponent - 1), 2**exponent
            params[par_name] = lower if value - lower < upper - value else upper
        res = defaultdict(lambda: defaultdict(lambda: {}))
        for key, value in params.items():
            big_key, small_key = key.split("/")[0], "/".join(key.split("/")[1:])
            if len(small_key.split("/")) == 1:
                res[big_key][small_key] = value
            else:
                group, name = small_key.split("/")
                res[big_key][group][name] = value
        model = self.data.loc[search_name]["general"]["model_name"]
        return res, model


class Suggestions(SavedRuns):
    """A class that manages operations with suggestion records."""

    def save_suggestion(self, episode_name: str, parameters: Dict, meta_parameters):
        """Save a new suggestion record.

        Parameters
        ----------
        episode_name : str
            the name of the suggestion record
        parameters : dict
            the parameters to save
        meta_parameters : dict
            the parameters stored in the "meta" group

        """
        pars = deepcopy(parameters)
        pars["meta"] = meta_parameters
        super().save_episode(episode_name, pars, behaviors_dict=None)


class SavedStores:
    """A class that manages operations with saved dataset records."""

    def __init__(self, path):
        """Initialize the class.

        Parameters
        ----------
        path : str
            the path to the pickled SavedRuns dataframe

        """
        self.path = path
        self.data = pd.read_pickle(path)
        # Parameters that do not define the identity of a saved dataset.
        self.skip_keys = [
            "feature_save_path",
            "saved_data_path",
            "real_lens",
            "recompute_annotation",
        ]

    def clear(self) -> None:
        """Remove all datasets."""
        for dataset_name in self.data.index:
            self.remove_dataset(dataset_name)

    def dataset_names(self) -> List:
        """Get a list of dataset names.

        Returns
        -------
        dataset_names : List
            a list of string dataset names

        """
        return list(self.data.index)

    def remove(self, names: List) -> None:
        """Remove some datasets.

        Parameters
        ----------
        names : List
            a list of string names of the datasets to delete (unknown names are ignored)

        """
        for dataset_name in names:
            if dataset_name in self.data.index:
                self.remove_dataset(dataset_name)

    def remove_dataset(self, dataset_name: str) -> None:
        """Remove a dataset record.

        Parameters
        ----------
        dataset_name : str
            the name of the dataset to remove

        """
        if dataset_name in self.data.index:
            self.data = self.data.drop(index=dataset_name)
            self._save()

    @staticmethod
    def _is_missing(value) -> bool:
        """Return True only for scalar missing values (array-like values count as present)."""
        isnull = pd.isnull(value)
        return isinstance(isnull, bool) and isnull

    def find_name(self, parameters: Dict) -> str:
        """Find a record that satisfies the parameters (if it exists).

        A record matches when it equals `parameters` on every non-None, non-skipped key
        and has missing values in every other non-skipped column.

        Parameters
        ----------
        parameters : dict
            a dictionary of data parameters

        Returns
        -------
        name : str
            the name of a record that has the same parameters (None if it does not exist; the earliest if there are
            several)

        """
        query = {}
        for key, value in parameters.items():
            if value is None or key in self.skip_keys:
                continue
            if key not in self.data.columns:
                return None
            query[key] = value
        candidates = self.data[
            (self.data[list(query)] == pd.Series(query)).all(axis=1)
        ]
        for i in range(len(candidates)):
            row = candidates.iloc[i]
            if all(
                self._is_missing(row[key])
                for key in candidates.columns
                if key not in self.skip_keys and key not in query
            ):
                return row.name
        return None

    def save_store(self, episode_name: str, parameters: Dict) -> None:
        """Save a new saved dataset record.

        Nothing is added if a record with the same parameters already exists.

        Parameters
        ----------
        episode_name : str
            the name of the dataset
        parameters : dict
            a dictionary of data parameters

        """
        pars = deepcopy(parameters)
        for key in parameters:
            if key not in self.data.columns:
                self.data[key] = np.nan
        if self.find_name(pars) is None:
            self.data.loc[episode_name] = pars
            self._save()

    def _save(self):
        """Save the dataframe."""
        self.data.to_pickle(self.path)

    def check_name_validity(self, store_name: str) -> bool:
        """Check if a store name already exists.

        Parameters
        ----------
        store_name : str
            the name to check

        Returns
        -------
        result : bool
            True if the name can be used

        """
        return store_name not in self.data.index


def _check_str_conversion(params_origin: pd.DataFrame) -> pd.DataFrame:
    """Check if the parameters are in string format and convert them to the correct type."""
    params = deepcopy(params_origin)

    # Early return if no conversion needed
    try:
        # Check if conversion is needed by testing a known column
        test_value = params[("general", "exclusive")]
        if isinstance(params, pd.DataFrame):
            # For DataFrame, check the first non-null value
            test_value = test_value.dropna().iloc[0] if not test_value.dropna().empty else test_value.iloc[0]

        if not isinstance(test_value, str):
            return params
    except (KeyError, IndexError):
        # If the test column doesn't exist or is empty, return as-is
        return params

    def safe_eval(value: str) -> any:
        """Safely evaluate a string value with fallback handling."""
        if not isinstance(value, str):
            return value

        # Handle special case for odict_keys
        if value.startswith("set(odict_keys("):
            value = value.replace("set(odict_keys(", "set(").replace("))", ")")
        if value.startswith("ordereddict("):
            value = value.replace("ordered", "")
        try:
            result = eval(value)
            # Convert floats ending with .0 to integers
            if isinstance(result, float) and result.is_integer():
                return int(result)
            return result
        except (ValueError, SyntaxError, NameError, TypeError):
            # If eval fails, keep as string
            return value

    def convert_stats(value) -> Dict:
        """Convert stats values, handling both strings and pandas Series."""
        if isinstance(value, pd.Series):
            # Handle pandas Series by applying conversion to each element
            return
value.apply(lambda x: convert_stats(x) if isinstance(x, str) else x) 1259 elif isinstance(value, str): 1260 # Replace tensor(...) with torch.tensor(..., dtype=torch.float) for proper evaluation 1261 s_clean = re.sub(r'tensor\(\s*(\[.*?\])\s*\)', r'torch.tensor(\1, dtype=torch.float)', value, flags=re.DOTALL) 1262 # Evaluate the cleaned string safely 1263 try: 1264 import torch 1265 result = eval(s_clean, {"torch": torch}) 1266 return result 1267 except Exception as e: 1268 print(f"Error while converting string: {e}") 1269 return value # Return original value instead of None 1270 else: 1271 # Return non-string values as-is 1272 return value 1273 1274 if isinstance(params, pd.DataFrame): 1275 for col in params.columns: 1276 if isinstance(col, tuple) and len(col) == 2: # MultiIndex column 1277 if col[1] == "stats": 1278 # Special handling for stats columns - convert tensor strings 1279 params[col] = params[col].apply(convert_stats) 1280 else: 1281 # Regular string conversion for other columns 1282 string_mask = params[col].apply(lambda x: isinstance(x, str)) 1283 if string_mask.any(): 1284 params.loc[string_mask, col] = params.loc[string_mask, col].apply(safe_eval) 1285 else: 1286 # Handle Series with MultiIndex 1287 for key, value in params.items(): 1288 if isinstance(key, tuple) and len(key) == 2: # MultiIndex 1289 if key[1] == "stats": 1290 params[key] = convert_stats(value) 1291 else: 1292 params[key] = safe_eval(value) 1293 1294 return params
class Run:
    """A class that manages operations with a single episode record."""

    def __init__(
        self,
        episode_name: str,
        project_path: str,
        meta_path: str = None,
        params: Dict = None,
    ):
        """Initialize the class.

        Parameters
        ----------
        episode_name : str
            the name of the episode
        project_path : str
            the path to the project folder
        meta_path : str, optional
            the path to the pickled SavedRuns dataframe
        params : dict or pd.Series, optional
            alternative to meta_path: pre-loaded episode parameters (e.g. a pandas Series)

        Raises
        ------
        ValueError
            if the episode cannot be loaded from ``meta_path`` or neither ``meta_path`` nor ``params`` is given

        """
        self.name = episode_name
        self.project_path = project_path
        if meta_path is not None:
            try:
                self.params = pd.read_pickle(meta_path).loc[episode_name]
            except Exception as err:
                # an unreadable file and a missing episode are both reported as a missing
                # episode; the underlying error is chained for debugging
                raise ValueError(f"The {episode_name} episode does not exist!") from err
        elif params is not None:
            self.params = params
        else:
            raise ValueError("Either meta_path or params has to be not None")
        self.params = self._check_str_conversion()

    def _check_str_conversion(self):
        """Check if the parameters are in string format and convert them to the correct type."""
        return _check_str_conversion(self.params)

    def training_time(self) -> int:
        """Get the training time in seconds.

        Returns
        -------
        training_time : int
            the training time in seconds (``np.nan`` if no training time is recorded)

        """
        time_str = self.params["meta"].get("training_time")
        try:
            if time_str is None or np.isnan(time_str):
                return np.nan
        except TypeError:
            # strings like "01:02:03" cannot be passed to np.isnan
            pass
        h, m, s = time_str.split(":")
        seconds = int(h) * 3600 + int(m) * 60 + int(s)
        return seconds

    @staticmethod
    def _checkpoint_epoch(filename: str) -> Union[int, None]:
        """Parse the epoch index from a checkpoint file name (None if it cannot be parsed).

        The first five characters are skipped and the text before the first dot is
        read as an integer, e.g. ``epoch12.pt`` -> 12.
        """
        try:
            return int(filename[5:].split(".")[0])
        except ValueError:
            return None

    @staticmethod
    def _select_checkpoint(filenames: List, load_epoch: int = None) -> Union[str, None]:
        """Choose a checkpoint file name from a model folder listing.

        Parameters
        ----------
        filenames : list
            the file names found in the model folder
        load_epoch : int, optional
            the target epoch; if None, the checkpoint with the highest epoch is chosen

        Returns
        -------
        filename : str or None
            the chosen file name (None if there is no checkpoint to choose)

        """
        if len(filenames) == 0:
            return None
        numbered = []
        for name in filenames:
            epoch = Run._checkpoint_epoch(name)
            if epoch is not None:
                numbered.append((epoch, name))
        if load_epoch is None:
            if not numbered:
                # no epoch-numbered files: fall back to plain lexicographic order
                return sorted(filenames)[-1]
            # compare epochs numerically: as strings "epoch9.pt" sorts after "epoch10.pt"
            return max(numbered)[1]
        if not numbered:
            return None
        # the closest epoch wins; ties go to the earlier epoch so the result does not
        # depend on the (arbitrary) order returned by os.listdir
        return min(numbered, key=lambda item: (abs(item[0] - load_epoch), item[0]))[1]

    def model_file(self, load_epoch: int = None) -> Union[str, None]:
        """Get a checkpoint file path.

        Parameters
        ----------
        load_epoch : int, optional
            the epoch to load (the closest checkpoint will be chosen; if not given will be set to last)

        Returns
        -------
        checkpoint_path : str or None
            the path to the checkpoint (None if the model folder contains no checkpoints)

        """
        model_path = correct_path(
            self.params["training"]["model_save_path"], self.project_path
        )
        model_file = self._select_checkpoint(os.listdir(model_path), load_epoch)
        if model_file is None:
            return None
        return os.path.join(model_path, model_file)

    def dataset_name(self) -> str:
        """Get the dataset name.

        Returns
        -------
        dataset_name : str
            the name of the dataset record

        """
        data_path = correct_path(
            self.params["data"]["feature_save_path"], self.project_path
        )
        dataset_name = os.path.basename(data_path)
        return dataset_name

    def split_file(self) -> str:
        """Get the split file.

        Returns
        -------
        split_path : str
            the path to the split file

        """
        return correct_path(self.params["training"]["split_path"], self.project_path)

    def log_file(self) -> str:
        """Get the log file.

        Returns
        -------
        log_path : str
            the path to the log file

        """
        return correct_path(self.params["training"]["log_file"], self.project_path)

    def split_info(self) -> Dict:
        """Get the train/test/val split information.

        Returns
        -------
        split_info : dict
            a dictionary with [val_frac, test_frac, partition_method] keys and corresponding values

        """
        val_frac = self.params["training"]["val_frac"]
        test_frac = self.params["training"]["test_frac"]
        partition_method = self.params["training"]["partition_method"]
        return {
            "val_frac": val_frac,
            "test_frac": test_frac,
            "partition_method": partition_method,
        }

    def same_split_info(self, split_info: Dict) -> bool:
        """Check whether this episode has the same split information.

        Parameters
        ----------
        split_info : dict
            a dictionary with [val_frac, test_frac, partition_method] keys and corresponding values from another episode

        Returns
        -------
        result : bool
            if True, this episode has the same split information

        """
        self_split_info = self.split_info()
        for k in ["val_frac", "test_frac", "partition_method"]:
            if self_split_info[k] != split_info[k]:
                return False
        return True

    def get_metric_log(self, mode: str, metric_name: str) -> np.ndarray:
        """Get the metric log.

        Parameters
        ----------
        mode : {'train', 'val'}
            the mode to get the log from
        metric_name : str
            the metric to get the log for (has to be one of the metrics computed for this episode during training)

        Returns
        -------
        log : np.ndarray
            the log of metric values (empty if the metric was not computed during training)

        """
        metric_array = []
        with open(self.log_file()) as f:
            for line in f.readlines():
                if mode == "train" and line.startswith("[epoch"):
                    line = line.split("]: ")[1]
                elif mode == "val" and line.startswith("validation"):
                    line = line.split("validation: ")[1]
                else:
                    continue
                metrics = line.split(", ")

                metric_ind = np.where(
                    np.array([m.split()[0] for m in metrics]) == metric_name
                )[0]
                if len(metric_ind):
                    _, value = metrics[metric_ind[0]].split()
                    metric_array.append(float(value))
                else:
                    # no exact match: average every entry whose name starts with "<metric_name>_"
                    metric_inds = [
                        m for m in metrics if m.split()[0].split("_")[0] == metric_name
                    ]
                    if len(metric_inds):
                        beh_metrics_avg = np.mean(
                            [float(m.split()[1]) for m in metric_inds]
                        )
                        metric_array.append(beh_metrics_avg)

        return np.array(metric_array)

    def get_epoch_list(self, mode) -> List:
        """Get a list of epoch indices.

        Parameters
        ----------
        mode : {'train', 'val'}
            the mode to get the epoch list for

        Returns
        -------
        epoch_list : list
            a list of int epoch indices

        """
        epoch_list = []
        with open(self.log_file()) as f:
            for line in f.readlines():
                if line.startswith("[epoch"):
                    epoch = int(line[7:].split("]:")[0])
                    # NOTE(review): both modes collect every "[epoch" line, whether or not a
                    # validation line follows it; confirm this is intended for 'val'
                    if mode == "train":
                        epoch_list.append(epoch)
                    elif mode == "val":
                        epoch_list.append(epoch)
        return epoch_list

    def get_metrics(self) -> List:
        """Get a list of metric names in the episode log.

        Returns
        -------
        metrics : List
            a list of string metric names

        """
        metrics = []
        with open(self.log_file()) as f:
            for line in f.readlines():
                if line.startswith("[epoch"):
                    line = line.split("]: ")[1]
                elif line.startswith("validation"):
                    line = line.split("validation: ")[1]
                else:
                    continue
                metric_logs = line.split(", ")
                for metric in metric_logs:
                    name, _ = metric.split()
                    metrics.append(name)
                # the first metric line is enough to list the names
                break
        return metrics

    def unfinished(self) -> bool:
        """Check whether this episode was interrupted.

        Returns
        -------
        result : bool
            True if the number of epochs in the log file is smaller than in the parameters

        """
        num_epoch_theor = self.params["training"]["num_epochs"]
        log_file = self.log_file()
        if not isinstance(log_file, str):
            return False
        if not os.path.exists(log_file):
            return True
        with open(log_file) as f:
            num_epoch = 0
            val = False
            for line in f.readlines():
                num_epoch += 1
                # a validation line right after the first epoch line means each epoch
                # takes two lines in the log
                if num_epoch == 2 and line.startswith("validation"):
                    val = True
            if val:
                num_epoch //= 2
            return num_epoch < num_epoch_theor

    def get_class_ind(self, class_name: str) -> int:
        """Get the integer label from a class name.

        Parameters
        ----------
        class_name : str
            the name of the class

        Returns
        -------
        class_ind : int
            the integer label

        Raises
        ------
        ValueError
            if the class is not in the episode's behaviors dictionary

        """
        behaviors_dict = self.get_behaviors_dict()
        for k, v in behaviors_dict.items():
            if v == class_name:
                return k
        raise ValueError(
            f"The {class_name} class is not in classes predicted by {self.name} ({behaviors_dict})"
        )

    def get_behaviors_dict(self) -> Dict:
        """Get behaviors dictionary in the episode.

        Returns
        -------
        behaviors_dict : dict
            a dictionary with class indices as keys and labels as values

        """
        behavior_dict = self.params["meta"]["behaviors_dict"]
        if isinstance(behavior_dict, str):
            # the dictionary may be stored as its string representation
            behavior_dict = ast.literal_eval(behavior_dict)

        return behavior_dict

    def get_num_classes(self) -> int:
        """Get number of classes in episode.

        Returns
        -------
        num_classes : int
            the number of classes

        """
        # go through get_behaviors_dict so a string-encoded dictionary is parsed
        # instead of counting its characters
        return len(self.get_behaviors_dict())
A class that manages operations with a single episode record.
def __init__(
    self,
    episode_name: str,
    project_path: str,
    meta_path: str = None,
    params: Dict = None,
):
    """Initialize the class.

    Parameters
    ----------
    episode_name : str
        the name of the episode
    project_path : str
        the path to the project folder
    meta_path : str, optional
        the path to the pickled SavedRuns dataframe
    params : dict or pd.Series, optional
        alternative to meta_path: pre-loaded episode parameters (e.g. a pandas Series)

    Raises
    ------
    ValueError
        if the episode cannot be loaded from ``meta_path`` or neither ``meta_path`` nor ``params`` is given

    """
    self.name = episode_name
    self.project_path = project_path
    if meta_path is not None:
        try:
            self.params = pd.read_pickle(meta_path).loc[episode_name]
        except Exception as err:
            # an unreadable file and a missing episode are both reported as a missing
            # episode; the underlying error is chained for debugging
            raise ValueError(f"The {episode_name} episode does not exist!") from err
    elif params is not None:
        self.params = params
    else:
        raise ValueError("Either meta_path or params has to be not None")
    self.params = self._check_str_conversion()
Initialize the class.
Parameters
episode_name : str the name of the episode project_path : str the path to the project folder meta_path : str, optional the path to the pickled SavedRuns dataframe params : dict or pd.Series, optional alternative to meta_path: pre-loaded episode parameters (e.g. a pandas Series)
def training_time(self) -> int:
    """Return the recorded training duration in seconds.

    Returns
    -------
    training_time : int
        the training time in seconds (``np.nan`` when no time was recorded)

    """
    raw = self.params["meta"].get("training_time")
    if raw is None:
        return np.nan
    try:
        missing = np.isnan(raw)
    except TypeError:
        # strings such as "01:02:03" cannot be passed to np.isnan
        missing = False
    if missing:
        return np.nan
    hours, minutes, secs = raw.split(":")
    return int(hours) * 3600 + int(minutes) * 60 + int(secs)
Get the training time in seconds.
Returns
training_time : int the training time in seconds
def model_file(self, load_epoch: int = None) -> Union[str, None]:
    """Get a checkpoint file path.

    Parameters
    ----------
    load_epoch : int, optional
        the epoch to load (the closest checkpoint will be chosen; if not given will be set to last)

    Returns
    -------
    checkpoint_path : str or None
        the path to the checkpoint (None if the model folder contains no checkpoints)

    """

    def parse_epoch(filename):
        # skip the first five characters and read the integer before the first dot,
        # e.g. "epoch12.pt" -> 12; None if the name does not parse
        try:
            return int(filename[5:].split(".")[0])
        except ValueError:
            return None

    model_path = correct_path(
        self.params["training"]["model_save_path"], self.project_path
    )
    model_files = os.listdir(model_path)
    if len(model_files) == 0:
        return None
    numbered = []
    for name in model_files:
        epoch = parse_epoch(name)
        if epoch is not None:
            numbered.append((epoch, name))
    if load_epoch is None:
        if numbered:
            # compare epochs numerically: as strings "epoch9.pt" sorts after "epoch10.pt"
            model_file = max(numbered)[1]
        else:
            model_file = sorted(model_files)[-1]
    elif numbered:
        # the closest epoch wins; ties go to the earlier epoch so the result does not
        # depend on the (arbitrary) order returned by os.listdir
        model_file = min(
            numbered, key=lambda item: (abs(item[0] - load_epoch), item[0])
        )[1]
    else:
        return None
    return os.path.join(model_path, model_file)
Get a checkpoint file path.
Parameters
load_epoch : int, optional the epoch to load (the closest checkpoint will be chosen; if not given will be set to last)
Returns
checkpoint_path : str the path to the checkpoint
def dataset_name(self) -> str:
    """Return the name of the dataset record used by this episode.

    The name is the last path component of the episode's ``feature_save_path``
    after it has been passed through ``correct_path`` with the project path.

    Returns
    -------
    dataset_name : str
        the name of the dataset record

    """
    feature_path = self.params["data"]["feature_save_path"]
    return os.path.basename(correct_path(feature_path, self.project_path))
Get the dataset name.
Returns
dataset_name : str the name of the dataset record
def split_file(self) -> str:
    """Return the path of this episode's split file.

    Returns
    -------
    split_path : str
        the ``training`` / ``split_path`` parameter passed through ``correct_path`` with the project path

    """
    raw_path = self.params["training"]["split_path"]
    return correct_path(raw_path, self.project_path)
Get the split file.
Returns
split_path : str the path to the split file
def log_file(self) -> str:
    """Return the path of this episode's training log.

    Returns
    -------
    log_path : str
        the ``training`` / ``log_file`` parameter passed through ``correct_path`` with the project path

    """
    raw_path = self.params["training"]["log_file"]
    return correct_path(raw_path, self.project_path)
Get the log file.
Returns
log_path : str the path to the log file
def split_info(self) -> Dict:
    """Collect the parameters that define the train/test/val split.

    Returns
    -------
    split_info : dict
        a dictionary with [val_frac, test_frac, partition_method] keys and corresponding values

    """
    training = self.params["training"]
    return {
        key: training[key] for key in ("val_frac", "test_frac", "partition_method")
    }
Get the train/test/val split information.
Returns
split_info : dict a dictionary with [val_frac, test_frac, partition_method] keys and corresponding values
def same_split_info(self, split_info: Dict) -> bool:
    """Compare this episode's split parameters with another episode's.

    Parameters
    ----------
    split_info : dict
        a dictionary with [val_frac, test_frac, partition_method] keys and corresponding values from another episode

    Returns
    -------
    result : bool
        True when all three values are equal

    """
    own = self.split_info()
    keys = ("val_frac", "test_frac", "partition_method")
    # any() stops at the first differing key, like an early return would
    return not any(own[key] != split_info[key] for key in keys)
Check whether this episode has the same split information.
Parameters
split_info : dict a dictionary with [val_frac, test_frac, partition_method] keys and corresponding values from another episode
Returns
result : bool if True, this episode has the same split information
def get_metrics(self) -> List:
    """List the metric names found on the first metric line of the episode log.

    Returns
    -------
    metrics : List
        a list of string metric names

    """
    names = []
    with open(self.log_file()) as log:
        for line in log:
            if line.startswith("[epoch"):
                payload = line.split("]: ")[1]
            elif line.startswith("validation"):
                payload = line.split("validation: ")[1]
            else:
                continue
            # every entry is "<name> <value>"; the first matching line is enough
            for entry in payload.split(", "):
                name, _value = entry.split()
                names.append(name)
            return names
    return names
Get a list of metric names in the episode log.
Returns
metrics : List a list of string metric names
def get_metric_log(self, mode: str, metric_name: str) -> np.ndarray:
    """Read the per-epoch values of one metric from the episode log.

    Parameters
    ----------
    mode : {'train', 'val'}
        the mode to get the log from
    metric_name : str
        the metric to get the log for (has to be one of the metrics computed for this episode during training)

    Returns
    -------
    log : np.ndarray
        the log of metric values (empty if the metric was not computed during training)

    """
    if mode == "train":
        start, separator = "[epoch", "]: "
    elif mode == "val":
        start, separator = "validation", "validation: "
    else:
        start = separator = None
    values = []
    with open(self.log_file()) as log:
        for line in log:
            if start is None or not line.startswith(start):
                continue
            entries = line.split(separator)[1].split(", ")
            names = [entry.split()[0] for entry in entries]
            if metric_name in names:
                _name, value = entries[names.index(metric_name)].split()
                values.append(float(value))
                continue
            # no exact match: average every entry whose name starts with "<metric_name>_"
            grouped = [
                entry for entry in entries if entry.split()[0].split("_")[0] == metric_name
            ]
            if grouped:
                values.append(np.mean([float(entry.split()[1]) for entry in grouped]))
    return np.array(values)
Get the metric log.
Parameters
mode : {'train', 'val'} the mode to get the log from metric_name : str the metric to get the log for (has to be one of the metrics computed for this episode during training)
Returns
log : np.ndarray the log of metric values (empty if the metric was not computed during training)
def get_epoch_list(self, mode) -> List:
    """Collect the epoch indices written on "[epoch N]: ..." lines of the log.

    Parameters
    ----------
    mode : {'train', 'val'}
        the mode to get the epoch list for

    Returns
    -------
    epoch_list : list
        a list of int epoch indices (empty for any other mode)

    """
    epochs = []
    with open(self.log_file()) as log:
        for line in log:
            if not line.startswith("[epoch"):
                continue
            index = int(line[7:].split("]:")[0])
            # NOTE(review): 'train' and 'val' both collect every epoch line, whether or not
            # a validation line follows it; confirm this is intended for 'val'
            if mode == "train" or mode == "val":
                epochs.append(index)
    return epochs
Get a list of epoch indices.
Parameters
mode : {'train', 'val'} the mode to get the epoch list for
Returns
epoch_list : list a list of int epoch indices
def unfinished(self) -> bool:
    """Tell whether the log holds fewer epochs than the episode was configured for.

    Returns
    -------
    result : bool
        True if the number of epochs in the log file is smaller than in the parameters
        (also True when the log file is missing; False when the log path is not a string)

    """
    expected = self.params["training"]["num_epochs"]
    path = self.log_file()
    if not isinstance(path, str):
        return False
    if not os.path.exists(path):
        return True
    with open(path) as log:
        lines = log.readlines()
    logged = len(lines)
    # a validation line right after the first line means each epoch takes two lines
    if logged >= 2 and lines[1].startswith("validation"):
        logged //= 2
    return logged < expected
Check whether this episode was interrupted.
Returns
result : bool True if the number of epochs in the log file is smaller than in the parameters
def get_class_ind(self, class_name: str) -> int:
    """Get the integer label from a class name.

    Parameters
    ----------
    class_name : str
        the name of the class

    Returns
    -------
    class_ind : int
        the integer label

    Raises
    ------
    ValueError
        if the class is not in the episode's behaviors dictionary

    """
    behaviors_dict = self.params["meta"]["behaviors_dict"]
    if isinstance(behaviors_dict, str):
        # the dictionary may be stored as its string representation (as also
        # handled by get_behaviors_dict); parse it before looking up labels
        behaviors_dict = ast.literal_eval(behaviors_dict)
    for k, v in behaviors_dict.items():
        if v == class_name:
            return k
    raise ValueError(
        f"The {class_name} class is not in classes predicted by {self.name} ({behaviors_dict})"
    )
Get the integer label from a class name.
Parameters
class_name : str the name of the class
Returns
class_ind : int the integer label
def get_behaviors_dict(self) -> Dict:
    """Return the episode's mapping from class index to class label.

    Returns
    -------
    behaviors_dict : dict
        a dictionary with class indices as keys and labels as values

    """
    stored = self.params["meta"]["behaviors_dict"]
    # the dictionary may be stored as its string representation; parse it in that case
    return ast.literal_eval(stored) if isinstance(stored, str) else stored
Get behaviors dictionary in the episode.
Returns
behaviors_dict : dict a dictionary with class indices as keys and labels as values
def get_num_classes(self) -> int:
    """Get number of classes in episode.

    Returns
    -------
    num_classes : int
        the number of classes

    """
    behaviors_dict = self.params["meta"]["behaviors_dict"]
    if isinstance(behaviors_dict, str):
        # a string-encoded dictionary must be parsed; len() of the string would
        # count characters instead of classes
        behaviors_dict = ast.literal_eval(behaviors_dict)
    return len(behaviors_dict)
Get number of classes in episode.
Returns
num_classes : int the number of classes
class DecisionThresholds:
    """A class that saves and looks up tuned decision thresholds."""

    # metric parameters excluded from a record's key in both save_thresholds and find_thresholds
    _IGNORED_KEYS = ("average", "threshold_value", "ignored_classes")

    def __init__(self, path: str) -> None:
        """Initialize the class.

        Parameters
        ----------
        path : str
            the path to the pickled dataframe of saved decision thresholds

        """
        self.path = path
        self.data = pd.read_pickle(path)

    def _metric_columns(self, metric_name: str, metric_parameters: Dict) -> Dict:
        """Map the relevant metric parameters to ``(metric_name, parameter)`` column keys.

        The caller's ``metric_parameters`` dictionary is not modified.
        """
        return {
            (metric_name, k): v
            for k, v in metric_parameters.items()
            if k not in self._IGNORED_KEYS
        }

    def save_thresholds(
        self,
        episode_names: List,
        epochs: List,
        metric_name: str,
        metric_parameters: Dict,
        thresholds: List,
    ) -> None:
        """Add a new record.

        Parameters
        ----------
        episode_names : list
            the names of the episodes
        epochs : list
            the epoch index list
        metric_name : str
            the name of the metric the thresholds were tuned on
        metric_parameters : dict
            the metric parameter dictionary (not modified)
        thresholds : list
            a list of float decision thresholds

        """
        episodes = set(zip(episode_names, epochs))
        parameters = self._metric_columns(metric_name, metric_parameters)
        parameters["thresholds"] = thresholds
        parameters["episodes"] = episodes
        pars = {k: [v] for k, v in parameters.items()}
        self.data = pd.concat([self.data, pd.DataFrame.from_dict(pars)], axis=0)
        self._save()

    def find_thresholds(
        self,
        episode_names: List,
        epochs: List,
        metric_name: str,
        metric_parameters: Dict,
    ) -> Union[List, None]:
        """Find a record.

        Parameters
        ----------
        episode_names : list
            the names of the episodes
        epochs : list
            the epoch index list
        metric_name : str
            the name of the metric the thresholds were tuned on
        metric_parameters : dict
            the metric parameter dictionary (not modified)

        Returns
        -------
        thresholds : list
            a list of float decision thresholds (None if no matching record exists)

        """
        parameters = self._metric_columns(metric_name, metric_parameters)
        parameters["episodes"] = set(zip(episode_names, epochs))
        # None-valued parameters do not constrain the search; any other parameter
        # must exist as a column, otherwise no record can match
        query = {}
        for key, value in parameters.items():
            if value is None:
                continue
            if key not in self.data.columns:
                return None
            query[key] = value
        matches = self.data[(self.data[list(query)] == pd.Series(query)).all(axis=1)]
        if len(matches) > 0:
            return matches.iloc[0]["thresholds"]
        return None

    def _save(self) -> None:
        """Save the records."""
        self.data.copy().to_pickle(self.path)
A class that saves and looks up tuned decision thresholds.
def __init__(self, path: str) -> None:
    """Load the stored decision-threshold records.

    Parameters
    ----------
    path : str
        the path to the pickled dataframe of saved decision thresholds

    """
    self.path = path
    self.data = pd.read_pickle(self.path)
Initialize the class.
Parameters
path : str the path to the pickled dataframe of saved decision thresholds
def save_thresholds(
    self,
    episode_names: List,
    epochs: List,
    metric_name: str,
    metric_parameters: Dict,
    thresholds: List,
) -> None:
    """Add a new record.

    Parameters
    ----------
    episode_names : list
        the names of the episodes
    epochs : list
        the epoch index list
    metric_name : str
        the name of the metric the thresholds were tuned on
    metric_parameters : dict
        the metric parameter dictionary (not modified)
    thresholds : list
        a list of float decision thresholds

    """
    episodes = set(zip(episode_names, epochs))
    # these keys are left out of the record; filter them instead of popping them
    # so the caller's dictionary is not modified
    ignored = ("average", "threshold_value", "ignored_classes")
    parameters = {
        (metric_name, k): v for k, v in metric_parameters.items() if k not in ignored
    }
    parameters["thresholds"] = thresholds
    parameters["episodes"] = episodes
    pars = {k: [v] for k, v in parameters.items()}
    self.data = pd.concat([self.data, pd.DataFrame.from_dict(pars)], axis=0)
    self._save()
Add a new record.
Parameters
episode_names : list the names of the episodes epochs : list the epoch index list metric_name : str the name of the metric the thresholds were tuned on metric_parameters : dict the metric parameter dictionary thresholds : list a list of float decision thresholds
def find_thresholds(
    self,
    episode_names: List,
    epochs: List,
    metric_name: str,
    metric_parameters: Dict,
) -> Union[List, None]:
    """Find a record.

    Parameters
    ----------
    episode_names : list
        the names of the episodes
    epochs : list
        the epoch index list
    metric_name : str
        the name of the metric the thresholds were tuned on
    metric_parameters : dict
        the metric parameter dictionary (not modified)

    Returns
    -------
    thresholds : list
        a list of float decision thresholds (None if no matching record exists)

    """
    episodes = set(zip(episode_names, epochs))
    # filter instead of popping so the caller's dictionary is not modified
    ignored = ("average", "threshold_value", "ignored_classes")
    parameters = {
        (metric_name, k): v for k, v in metric_parameters.items() if k not in ignored
    }
    parameters["episodes"] = episodes
    # None-valued parameters do not constrain the search; any other parameter
    # must exist as a column, otherwise no record can match
    query = {}
    for key, value in parameters.items():
        if value is None:
            continue
        if key not in self.data.columns:
            return None
        query[key] = value
    matches = self.data[(self.data[list(query)] == pd.Series(query)).all(axis=1)]
    if len(matches) > 0:
        return matches.iloc[0]["thresholds"]
    return None
Find a record.
Parameters
episode_names : list the names of the episodes epochs : list the epoch index list metric_name : str the name of the metric the thresholds were tuned on metric_parameters : dict the metric parameter dictionary
Returns
thresholds : list a list of float decision thresholds
class SavedRuns:
    """A class that manages operations with all episode (or prediction) records.

    The records are kept in a pandas dataframe (``self.data``) indexed by episode
    name, with two-level ``(group, parameter)`` columns; modifying methods pickle
    the dataframe back to ``self.path``.
    """

    def __init__(self, path: str, project_path: str) -> None:
        """Initialize the class.

        Parameters
        ----------
        path : str
            the path to the pickled SavedRuns dataframe
        project_path : str
            the path to the project folder

        """
        self.path = path
        self.project_path = project_path
        self.data = pd.read_pickle(path)
        self.data = _check_str_conversion(self.data)

    def update(
        self,
        data: pd.DataFrame,
        data_path: str,
        annotation_path: str,
        name_map: Dict = None,
        force: bool = False,
    ) -> None:
        """Update with new data.

        The model, log and split paths of every incoming episode are re-pointed
        into this project's ``results`` folder and the data/annotation paths are
        replaced with the ones given.

        Parameters
        ----------
        data : pd.DataFrame
            the new dataframe
        data_path : str
            the new data path
        annotation_path : str
            the new annotation path
        name_map : dict, optional
            the name change dictionary; keys are old episode names and values are new episode names
        force : bool, default False
            replace existing episodes if `True`

        Raises
        ------
        RuntimeError
            if an incoming episode name is already taken and `force` is False

        """
        if name_map is None:
            name_map = {}
        data = data.rename(index=name_map)
        for episode in data.index:
            new_model = os.path.join(self.project_path, "results", "model", episode)
            data.loc[episode, ("training", "model_save_path")] = new_model
            new_log = os.path.join(
                self.project_path, "results", "logs", f"{episode}.txt"
            )
            data.loc[episode, ("training", "log_file")] = new_log
            old_split = data.loc[episode, ("training", "split_path")]
            # a missing split path may be stored as None or as NaN; both mean "no split"
            if old_split is None or (
                isinstance(old_split, float) and np.isnan(old_split)
            ):
                new_split = None
            else:
                new_split = os.path.join(
                    self.project_path, "results", "splits", os.path.basename(old_split)
                )
            data.loc[episode, ("training", "split_path")] = new_split
            data.loc[episode, ("data", "data_path")] = data_path
            data.loc[episode, ("data", "annotation_path")] = annotation_path
            if episode in self.data.index:
                if force:
                    self.data = self.data.drop(index=[episode])
                else:
                    raise RuntimeError(f"The {episode} episode name is already taken!")
        self.data = pd.concat([self.data, data])
        self._save()

    def get_subset(self, episode_names: List) -> pd.DataFrame:
        """Get a subset of the raw metadata.

        Parameters
        ----------
        episode_names : list
            a list of the episodes to include

        Returns
        -------
        subset : pd.DataFrame
            the subset of the raw metadata

        Raises
        ------
        ValueError
            if any of the episodes is not in the records

        """
        for episode in episode_names:
            if episode not in self.data.index:
                raise ValueError(
                    f"The {episode} episode is not in the records; please run `Project.list_episodes()` to explore the records"
                )
        return self.data.loc[episode_names]

    def get_saved_data_path(self, episode_name: str) -> str:
        """Get the `saved_data_path` parameter for the episode.

        Parameters
        ----------
        episode_name : str
            the name of the episode

        Returns
        -------
        saved_data_path : str
            the saved data path

        """
        return self.data.loc[episode_name]["data"]["saved_data_path"]

    def check_name_validity(self, episode_name: str) -> bool:
        """Check if an episode name already exists.

        Parameters
        ----------
        episode_name : str
            the name to check

        Returns
        -------
        result : bool
            True if the name can be used

        """
        return episode_name not in self.data.index

    def update_episode_metrics(self, episode_name: str, metrics: Dict) -> None:
        """Update meta data with evaluation results.

        Parameters
        ----------
        episode_name : str
            the name of the episode to update
        metrics : dict
            a dictionary of the metrics

        """
        for key, value in metrics.items():
            self.data.loc[episode_name, ("results", key)] = value
        self._save()

    def save_episode(
        self,
        episode_name: str,
        parameters: Dict,
        behaviors_dict: Dict,
        suppress_validation: bool = False,
        training_time: str = None,
    ) -> None:
        """Save a new run record.

        Parameters
        ----------
        episode_name : str
            the name of the episode
        parameters : dict
            the parameters to save
        behaviors_dict : dict
            the dictionary of behaviors (keys are indices, values are names)
        suppress_validation : bool, default False
            if True, existing episode with the same name will be overwritten
        training_time : str, optional
            the training time in '%H:%M:%S' format

        Raises
        ------
        ValueError
            if the episode already exists and `suppress_validation` is False

        """
        if not suppress_validation and episode_name in self.data.index:
            raise ValueError(f"Episode {episode_name} already exists!")
        pars = deepcopy(parameters)
        if "meta" not in pars:
            pars["meta"] = {
                "time": strftime("%Y-%m-%d %H:%M:%S", localtime()),
                "behaviors_dict": behaviors_dict,
            }
        else:
            pars["meta"]["time"] = strftime("%Y-%m-%d %H:%M:%S", localtime())
            pars["meta"]["behaviors_dict"] = behaviors_dict
        if training_time is not None:
            pars["meta"]["training_time"] = training_time
        if len(parameters.keys()) > 1:
            # keep only the selected loss / metric / ssl parameter groups, flattened to top level
            pars["losses"] = pars["losses"].get(pars["general"]["loss_function"], {})
            for metric_name in pars["general"]["metric_functions"]:
                pars[metric_name] = pars["metrics"].get(metric_name, {})
            if pars["general"].get("ssl", None) is not None:
                for ssl_name in pars["general"]["ssl"]:
                    pars[ssl_name] = pars["ssl"].get(ssl_name, {})
            for group_name in ["metrics", "ssl"]:
                if group_name in pars:
                    pars.pop(group_name)
        data = {
            (big_key, small_key): value
            for big_key, big_value in pars.items()
            for small_key, value in big_value.items()
        }
        list_keys = []
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", message="DataFrame is highly fragmented")
            for k, v in data.items():
                if k not in self.data.columns:
                    self.data[k] = np.nan
                if isinstance(v, list):
                    list_keys.append(k)
            # list values can only be stored in object-dtype columns
            for k in list_keys:
                self.data[k] = self.data[k].astype(object)
            self.data.loc[episode_name] = data
        self._save()

    def load_parameters(self, episode_name: str) -> Dict:
        """Load the task parameters from a record.

        Parameters
        ----------
        episode_name : str
            the name of the episode to load

        Returns
        -------
        parameters : dict
            the loaded task parameters

        """
        parameters = defaultdict(lambda: defaultdict(lambda: {}))
        episode = self.data.loc[episode_name].dropna().to_dict()
        keys = ["data", "augmentations", "general", "training", "model", "features"]
        for key in episode:
            big_key, small_key = key
            if big_key in keys:
                parameters[big_key][small_key] = episode[key]
        ssl_keys = parameters["general"].get("ssl", None)
        metric_keys = parameters["general"].get("metric_functions", None)
        loss_key = parameters["general"]["loss_function"]
        if ssl_keys is None:
            ssl_keys = []
        if metric_keys is None:
            metric_keys = []
        # regroup the flattened ssl / metric / loss groups (see save_episode)
        for key in episode:
            big_key, small_key = key
            if big_key in ssl_keys:
                parameters["ssl"][big_key][small_key] = episode[key]
            elif big_key in metric_keys:
                parameters["metrics"][big_key][small_key] = episode[key]
            elif big_key == "losses":
                parameters["losses"][loss_key][small_key] = episode[key]
        parameters = {k: dict(v) for k, v in parameters.items()}
        parameters["general"]["num_classes"] = Run(
            episode_name, self.project_path, params=self.data.loc[episode_name]
        ).get_num_classes()
        return parameters

    def get_active_datasets(self) -> List:
        """Get a list of names of datasets that are used by unfinished episodes.

        Returns
        -------
        active_datasets : list
            a list of dataset names used by unfinished episodes

        """
        active_datasets = []
        for episode_name in self.unfinished_episodes():
            run = Run(
                episode_name, self.project_path, params=self.data.loc[episode_name]
            )
            active_datasets.append(run.dataset_name())
        return active_datasets

    def list_episodes(
        self,
        episode_names: List = None,
        value_filter: str = "",
        display_parameters: List = None,
    ) -> pd.DataFrame:
        """Get a filtered pandas dataframe with episode metadata.

        Parameters
        ----------
        episode_names : List
            a list of strings of episode names
        value_filter : str
            a string of filters to apply of this general structure:
            'group_name1/par_name1::(<>=)value1,group_name2/par_name2::(<>=)value2', e.g.
            'data/overlap::=50,results/recall::>0.5,data/feature_extraction::=kinematic'
        display_parameters : List
            list of parameters to display (e.g. ['data/overlap', 'results/recall'])

        Returns
        -------
        pandas.DataFrame
            the filtered dataframe

        Raises
        ------
        ValueError
            if `value_filter` cannot be parsed or uses an unsupported sign

        """
        if episode_names is not None:
            data = deepcopy(self.data.loc[episode_names])
        else:
            data = deepcopy(self.data)
        if len(data) == 0:
            return pd.DataFrame()
        filters = value_filter.split(",")
        if filters == [""]:
            filters = []
        try:
            for f in filters:
                par_name, condition = f.split("::")
                group_name, par_name = par_name.split("/")
                # one-character sign, optionally followed by "=" (">=", "<=", "!=")
                sign, value = condition[0], condition[1:]
                if value[0] == "=":
                    sign += "="
                    value = value[1:]
                try:
                    value = float(value)
                except ValueError:
                    if value == "True":
                        value = True
                    elif value == "False":
                        value = False
                    elif value == "None":
                        value = None
                if value is None:
                    if sign == "=":
                        data = data[data[group_name][par_name].isna()]
                    elif sign == "!=":
                        data = data[~data[group_name][par_name].isna()]
                    else:
                        raise ValueError("Only = and != can be used with None")
                elif sign == ">":
                    data = data[data[group_name][par_name] > value]
                elif sign == ">=":
                    data = data[data[group_name][par_name] >= value]
                elif sign == "<":
                    data = data[data[group_name][par_name] < value]
                elif sign == "<=":
                    data = data[data[group_name][par_name] <= value]
                elif sign == "=":
                    data = data[data[group_name][par_name] == value]
                elif sign == "!=":
                    data = data[data[group_name][par_name] != value]
                else:
                    raise ValueError(
                        "Please use one of the signs: [>, <, >=, <=, =, !=]"
                    )
        except (ValueError, IndexError) as e:
            # IndexError comes from an empty condition or value (e.g. 'data/overlap::>')
            raise ValueError(
                f"The {value_filter} filter is not valid, please use the following format:"
                f" 'group1/parameter1::[sign][value],group2/parameter2::[sign][value]', "
                f"e.g. 'training/num_epochs::>=200,model/num_f_maps::=128,meta/time::>2022-06-01'"
            ) from e
        if display_parameters is not None:
            display_parameters = [
                (x.split("/")[0], x.split("/")[1]) if isinstance(x, str) else x
                for x in display_parameters
            ]
            display_parameters = [x for x in display_parameters if x in data.columns]
            data = data[display_parameters]
        return data

    def rename_episode(self, episode_name, new_episode_name):
        """Rename an episode.

        Only the record is changed (including the model folder and log file paths
        it stores); no files are moved here.

        Parameters
        ----------
        episode_name : str
            the name of the episode to rename
        new_episode_name : str
            the new name of the episode

        Raises
        ------
        ValueError
            if `episode_name` does not exist or `new_episode_name` is already taken

        """
        if episode_name not in self.data.index or new_episode_name in self.data.index:
            raise ValueError("The names are wrong")
        self.data.loc[new_episode_name] = self.data.loc[episode_name]
        # the model folder is named after the episode (see `update`)
        model_path = self.data.loc[new_episode_name, ("training", "model_save_path")]
        self.data.loc[new_episode_name, ("training", "model_save_path")] = (
            os.path.join(os.path.dirname(model_path), new_episode_name)
        )
        log_path = self.data.loc[new_episode_name, ("training", "log_file")]
        self.data.loc[new_episode_name, ("training", "log_file")] = os.path.join(
            os.path.dirname(log_path), f"{new_episode_name}.txt"
        )
        self.data = self.data.drop(index=episode_name)
        self._save()

    def remove_episode(self, episode_name: str) -> None:
        """Remove all model, logs and metafile records related to an episode.

        Parameters
        ----------
        episode_name : str
            the name of the episode to remove

        """
        if episode_name in self.data.index:
            self.data = self.data.drop(index=episode_name)
            self._save()

    def unfinished_episodes(self) -> List:
        """Get a list of unfinished episodes (currently running or interrupted).

        Returns
        -------
        interrupted_episodes: List
            a list of string names of unfinished episodes in the records

        """
        unfinished = []
        for name, params in self.data.iterrows():
            if Run(name, project_path=self.project_path, params=params).unfinished():
                unfinished.append(name)
        return unfinished

    def update_episode_results(
        self,
        episode_name: str,
        logs: Tuple,
        training_time: str = None,
    ) -> None:
        """Add results to an episode record.

        Parameters
        ----------
        episode_name : str
            the name of the episode to update
        logs : tuple
            the logs returned by task.train(); the second element is the metrics
            log whose "val" entry maps metric names to sequences of values (the
            last value is stored)
        training_time : str, optional
            the training time

        """
        metrics_log = logs[1]
        results = {}
        for key, value in metrics_log["val"].items():
            results[("results", key)] = value[-1]
        if training_time is not None:
            results[("meta", "training_time")] = training_time
        for k, v in results.items():
            self.data.loc[episode_name, k] = v
        self._save()

    def get_runs(self, episode_name: str) -> List:
        """Get a list of runs with this episode name (episodes like `episode_name#0`).

        Parameters
        ----------
        episode_name : str
            the name of the episode

        Returns
        -------
        runs_list : List
            a list of string run names

        """
        if episode_name is None:
            return []
        index = self.data.index
        runs_list = []
        for name in index:
            if name.startswith(episode_name):
                if "::" in name:
                    split = name.split("::")
                else:
                    split = name.split("#")
                if split[0] == episode_name:
                    if len(split) > 1 and split[-1].isnumeric() or len(split) == 1:
                        runs_list.append(name)
                elif name == episode_name:
                    runs_list.append(name)
        return runs_list

    def _save(self):
        """Save the dataframe."""
        self.data.copy().to_pickle(self.path)
A class that manages operations with all episode (or prediction) records.
def __init__(self, path: str, project_path: str) -> None:
    """Load the pickled records table and remember where it lives.

    Parameters
    ----------
    path : str
        the path to the pickled SavedRuns dataframe
    project_path : str
        the path to the project folder

    """
    self.project_path = project_path
    self.path = path
    self.data = pd.read_pickle(self.path)
    # normalize values that were stored as strings
    self.data = _check_str_conversion(self.data)
Initialize the class.
Parameters
path : str the path to the pickled SavedRuns dataframe project_path : str the path to the project folder
def update(
    self,
    data: pd.DataFrame,
    data_path: str,
    annotation_path: str,
    name_map: Dict = None,
    force: bool = False,
) -> None:
    """Update with new data.

    The model, log and split paths of every incoming episode are re-pointed
    into this project's ``results`` folder and the data/annotation paths are
    replaced with the ones given.

    Parameters
    ----------
    data : pd.DataFrame
        the new dataframe
    data_path : str
        the new data path
    annotation_path : str
        the new annotation path
    name_map : dict, optional
        the name change dictionary; keys are old episode names and values are new episode names
    force : bool, default False
        replace existing episodes if `True`

    Raises
    ------
    RuntimeError
        if an incoming episode name is already taken and `force` is False

    """
    if name_map is None:
        name_map = {}
    data = data.rename(index=name_map)
    for episode in data.index:
        new_model = os.path.join(self.project_path, "results", "model", episode)
        data.loc[episode, ("training", "model_save_path")] = new_model
        new_log = os.path.join(self.project_path, "results", "logs", f"{episode}.txt")
        data.loc[episode, ("training", "log_file")] = new_log
        old_split = data.loc[episode, ("training", "split_path")]
        # a missing split path may be stored as None or as NaN; both mean "no split"
        if old_split is None or (isinstance(old_split, float) and np.isnan(old_split)):
            new_split = None
        else:
            new_split = os.path.join(
                self.project_path, "results", "splits", os.path.basename(old_split)
            )
        data.loc[episode, ("training", "split_path")] = new_split
        data.loc[episode, ("data", "data_path")] = data_path
        data.loc[episode, ("data", "annotation_path")] = annotation_path
        if episode in self.data.index:
            if force:
                self.data = self.data.drop(index=[episode])
            else:
                raise RuntimeError(f"The {episode} episode name is already taken!")
    self.data = pd.concat([self.data, data])
    self._save()
Update with new data.
Parameters
data : pd.DataFrame
the new dataframe
data_path : str
the new data path
annotation_path : str
the new annotation path
name_map : dict, optional
the name change dictionary; keys are old episode names and values are new episode names
force : bool, default False
replace existing episodes if True
def get_subset(self, episode_names: List) -> pd.DataFrame:
    """Return the rows of the raw metadata for the requested episodes.

    Parameters
    ----------
    episode_names : list
        a list of the episodes to include

    Returns
    -------
    subset : pd.DataFrame
        the subset of the raw metadata

    Raises
    ------
    ValueError
        if one of the names is not in the records

    """
    known_names = self.data.index
    for name in episode_names:
        if name in known_names:
            continue
        raise ValueError(
            f"The {name} episode is not in the records; please run `Project.list_episodes()` to explore the records"
        )
    return self.data.loc[episode_names]
Get a subset of the raw metadata.
Parameters
episode_names : list a list of the episodes to include
Returns
subset : pd.DataFrame the subset of the raw metadata
def get_saved_data_path(self, episode_name: str) -> str:
    """Return the `data/saved_data_path` value recorded for an episode.

    Parameters
    ----------
    episode_name : str
        the name of the episode

    Returns
    -------
    saved_data_path : str
        the saved data path

    """
    record = self.data.loc[episode_name]
    data_group = record["data"]
    return data_group["saved_data_path"]
Get the saved_data_path parameter for the episode.
Parameters
episode_name : str the name of the episode
Returns
saved_data_path : str the saved data path
def check_name_validity(self, episode_name: str) -> bool:
    """Tell whether an episode name is still free.

    Parameters
    ----------
    episode_name : str
        the name to check

    Returns
    -------
    result : bool
        True if the name can be used (it is not in the records yet)

    """
    return episode_name not in self.data.index
Check if an episode name already exists.
Parameters
episode_name : str the name to check
Returns
result : bool True if the name can be used
def update_episode_metrics(self, episode_name: str, metrics: Dict) -> None:
    """Write evaluation results into the `results` group of an episode record.

    Parameters
    ----------
    episode_name : str
        the name of the episode to update
    metrics : dict
        a dictionary of the metrics (metric name -> value)

    """
    for metric_name in metrics:
        self.data.loc[episode_name, ("results", metric_name)] = metrics[metric_name]
    self._save()
Update meta data with evaluation results.
Parameters
episode_name : str the name of the episode to update metrics : dict a dictionary of the metrics
def save_episode(
    self,
    episode_name: str,
    parameters: Dict,
    behaviors_dict: Dict,
    suppress_validation: bool = False,
    training_time: str = None,
) -> None:
    """Store a run record built from its parameters.

    Parameters
    ----------
    episode_name : str
        the name of the episode
    parameters : dict
        the parameters to save
    behaviors_dict : dict
        the dictionary of behaviors (keys are indices, values are names)
    suppress_validation : bool, default False
        if True, existing episode with the same name will be overwritten
    training_time : str, optional
        the training time in '%H:%M:%S' format

    Raises
    ------
    ValueError
        if the episode exists and `suppress_validation` is False

    """
    if episode_name in self.data.index and not suppress_validation:
        raise ValueError(f"Episode {episode_name} already exists!")
    pars = deepcopy(parameters)
    timestamp = strftime("%Y-%m-%d %H:%M:%S", localtime())
    meta = pars.setdefault("meta", {})
    meta["time"] = timestamp
    meta["behaviors_dict"] = behaviors_dict
    if training_time is not None:
        meta["training_time"] = training_time
    if len(parameters) > 1:
        # flatten the selected loss / metric / ssl groups to the top level
        general = pars["general"]
        pars["losses"] = pars["losses"].get(general["loss_function"], {})
        for metric_name in general["metric_functions"]:
            pars[metric_name] = pars["metrics"].get(metric_name, {})
        ssl_names = general.get("ssl", None)
        if ssl_names is not None:
            for ssl_name in ssl_names:
                pars[ssl_name] = pars["ssl"].get(ssl_name, {})
        for group_name in ("metrics", "ssl"):
            pars.pop(group_name, None)
    row = {
        (group, field): value
        for group, group_values in pars.items()
        for field, value in group_values.items()
    }
    object_columns = []
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", message="DataFrame is highly fragmented")
        for column, value in row.items():
            if column not in self.data.columns:
                self.data[column] = np.nan
            if isinstance(value, list):
                object_columns.append(column)
        # list values can only be stored in object-dtype columns
        for column in object_columns:
            self.data[column] = self.data[column].astype(object)
        self.data.loc[episode_name] = row
    self._save()
Save a new run record.
Parameters
episode_name : str the name of the episode parameters : dict the parameters to save behaviors_dict : dict the dictionary of behaviors (keys are indices, values are names) suppress_validation : bool, default False if True, existing episode with the same name will be overwritten training_time : str, optional the training time in '%H:%M:%S' format
def load_parameters(self, episode_name: str) -> Dict:
    """Rebuild the nested task parameters dictionary from a flat record.

    Parameters
    ----------
    episode_name : str
        the name of the episode to load

    Returns
    -------
    parameters : dict
        the loaded task parameters

    """
    nested = defaultdict(lambda: defaultdict(lambda: {}))
    record = self.data.loc[episode_name].dropna().to_dict()
    task_groups = {"data", "augmentations", "general", "training", "model", "features"}
    for (group, field), value in record.items():
        if group in task_groups:
            nested[group][field] = value
    general = nested["general"]
    ssl_names = general.get("ssl", None)
    metric_names = general.get("metric_functions", None)
    loss_name = general["loss_function"]
    if ssl_names is None:
        ssl_names = []
    if metric_names is None:
        metric_names = []
    # regroup the ssl / metric / loss groups that were flattened on save
    for (group, field), value in record.items():
        if group in ssl_names:
            nested["ssl"][group][field] = value
        elif group in metric_names:
            nested["metrics"][group][field] = value
        elif group == "losses":
            nested["losses"][loss_name][field] = value
    parameters = {group: dict(values) for group, values in nested.items()}
    run = Run(episode_name, self.project_path, params=self.data.loc[episode_name])
    parameters["general"]["num_classes"] = run.get_num_classes()
    return parameters
Load the task parameters from a record.
Parameters
episode_name : str the name of the episode to load
Returns
parameters : dict the loaded task parameters
def get_active_datasets(self) -> List:
    """Collect the dataset names referenced by unfinished episodes.

    Returns
    -------
    active_datasets : list
        a list of dataset names used by unfinished episodes

    """
    return [
        Run(name, self.project_path, params=self.data.loc[name]).dataset_name()
        for name in self.unfinished_episodes()
    ]
Get a list of names of datasets that are used by unfinished episodes.
Returns
active_datasets : list a list of dataset names used by unfinished episodes
def list_episodes(
    self,
    episode_names: List = None,
    value_filter: str = "",
    display_parameters: List = None,
) -> pd.DataFrame:
    """Get a filtered pandas dataframe with episode metadata.

    Parameters
    ----------
    episode_names : List
        a list of strings of episode names
    value_filter : str
        a string of filters to apply of this general structure:
        'group_name1/par_name1::(<>=)value1,group_name2/par_name2::(<>=)value2', e.g.
        'data/overlap::=50,results/recall::>0.5,data/feature_extraction::=kinematic'
    display_parameters : List
        list of parameters to display (e.g. ['data/overlap', 'results/recall'])

    Returns
    -------
    pandas.DataFrame
        the filtered dataframe

    Raises
    ------
    ValueError
        if `value_filter` cannot be parsed or uses an unsupported sign

    """
    if episode_names is not None:
        data = deepcopy(self.data.loc[episode_names])
    else:
        data = deepcopy(self.data)
    if len(data) == 0:
        return pd.DataFrame()
    filters = value_filter.split(",")
    if filters == [""]:
        filters = []
    try:
        for f in filters:
            par_name, condition = f.split("::")
            group_name, par_name = par_name.split("/")
            # one-character sign, optionally followed by "=" (">=", "<=", "!=")
            sign, value = condition[0], condition[1:]
            if value[0] == "=":
                sign += "="
                value = value[1:]
            try:
                value = float(value)
            except ValueError:
                if value == "True":
                    value = True
                elif value == "False":
                    value = False
                elif value == "None":
                    value = None
            if value is None:
                if sign == "=":
                    data = data[data[group_name][par_name].isna()]
                elif sign == "!=":
                    data = data[~data[group_name][par_name].isna()]
                else:
                    raise ValueError("Only = and != can be used with None")
            elif sign == ">":
                data = data[data[group_name][par_name] > value]
            elif sign == ">=":
                data = data[data[group_name][par_name] >= value]
            elif sign == "<":
                data = data[data[group_name][par_name] < value]
            elif sign == "<=":
                data = data[data[group_name][par_name] <= value]
            elif sign == "=":
                data = data[data[group_name][par_name] == value]
            elif sign == "!=":
                data = data[data[group_name][par_name] != value]
            else:
                raise ValueError("Please use one of the signs: [>, <, >=, <=, =, !=]")
    except (ValueError, IndexError) as e:
        # IndexError comes from an empty condition or value (e.g. 'data/overlap::>')
        raise ValueError(
            f"The {value_filter} filter is not valid, please use the following format:"
            f" 'group1/parameter1::[sign][value],group2/parameter2::[sign][value]', "
            f"e.g. 'training/num_epochs::>=200,model/num_f_maps::=128,meta/time::>2022-06-01'"
        ) from e
    if display_parameters is not None:
        # accept both 'group/name' strings and (group, name) tuples, element by element
        display_parameters = [
            (x.split("/")[0], x.split("/")[1]) if isinstance(x, str) else x
            for x in display_parameters
        ]
        display_parameters = [x for x in display_parameters if x in data.columns]
        data = data[display_parameters]
    return data
Get a filtered pandas dataframe with episode metadata.
Parameters
episode_names : List a list of strings of episode names value_filter : str a string of filters to apply of this general structure: 'group_name1/par_name1::(<>=)value1,group_name2/par_name2::(<>=)value2', e.g. 'data/overlap::=50,results/recall::>0.5,data/feature_extraction::=kinematic' display_parameters : List list of parameters to display (e.g. ['data/overlap', 'results/recall'])
Returns
pandas.DataFrame the filtered dataframe
def rename_episode(self, episode_name, new_episode_name):
    """Rename an episode.

    Only the record is changed (including the model folder and log file paths
    it stores); no files are moved here.

    Parameters
    ----------
    episode_name : str
        the name of the episode to rename
    new_episode_name : str
        the new name of the episode

    Raises
    ------
    ValueError
        if `episode_name` does not exist or `new_episode_name` is already taken

    """
    if episode_name not in self.data.index or new_episode_name in self.data.index:
        raise ValueError("The names are wrong")
    self.data.loc[new_episode_name] = self.data.loc[episode_name]
    # the path column is ("training", "model_save_path"); the folder is named after the episode
    model_path = self.data.loc[new_episode_name, ("training", "model_save_path")]
    self.data.loc[new_episode_name, ("training", "model_save_path")] = os.path.join(
        os.path.dirname(model_path), new_episode_name
    )
    log_path = self.data.loc[new_episode_name, ("training", "log_file")]
    self.data.loc[new_episode_name, ("training", "log_file")] = os.path.join(
        os.path.dirname(log_path), f"{new_episode_name}.txt"
    )
    self.data = self.data.drop(index=episode_name)
    self._save()
Rename an episode.
Parameters
episode_name : str the name of the episode to rename new_episode_name : str the new name of the episode
def remove_episode(self, episode_name: str) -> None:
    """Drop an episode's record, if present, and persist the change.

    Parameters
    ----------
    episode_name : str
        the name of the episode to remove

    """
    if episode_name not in self.data.index:
        return
    self.data = self.data.drop(index=episode_name)
    self._save()
Remove all model, logs and metafile records related to an episode.
Parameters
episode_name : str the name of the episode to remove
def unfinished_episodes(self) -> List:
    """List the episodes whose runs are not finished (running or interrupted).

    Returns
    -------
    interrupted_episodes: List
        a list of string names of unfinished episodes in the records

    """
    return [
        name
        for name, params in self.data.iterrows()
        if Run(name, project_path=self.project_path, params=params).unfinished()
    ]
Get a list of unfinished episodes (currently running or interrupted).
Returns
interrupted_episodes: List a list of string names of unfinished episodes in the records
def update_episode_results(
    self,
    episode_name: str,
    logs: Tuple,
    training_time: str = None,
) -> None:
    """Store the final validation metrics (and optionally the training time).

    Parameters
    ----------
    episode_name : str
        the name of the episode to update
    logs : tuple
        the logs returned by task.train(); the second element is the metrics
        log whose "val" entry maps metric names to sequences of values (the
        last value is stored)
    training_time : str, optional
        the training time

    """
    validation_log = logs[1]["val"]
    updates = {("results", name): values[-1] for name, values in validation_log.items()}
    if training_time is not None:
        updates[("meta", "training_time")] = training_time
    for column, value in updates.items():
        self.data.loc[episode_name, column] = value
    self._save()
Add results to an episode record.
Parameters
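episode_name : str the name of the episode to update logs : tuple the logs returned by task.train(); the second element is the metrics log, whose "val" entry maps metric names to sequences of values (the last value is stored) training_time : str, optional the training time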
def get_runs(self, episode_name: str) -> List:
    """Find the runs belonging to an episode (names like `episode_name#0`).

    Parameters
    ----------
    episode_name : str
        the name of the episode

    Returns
    -------
    runs_list : List
        a list of string run names

    """
    if episode_name is None:
        return []
    matches = []
    for name in self.data.index:
        if not name.startswith(episode_name):
            continue
        separator = "::" if "::" in name else "#"
        parts = name.split(separator)
        if parts[0] == episode_name:
            # either the bare episode name or a numeric run suffix
            if len(parts) == 1 or parts[-1].isnumeric():
                matches.append(name)
        elif name == episode_name:
            matches.append(name)
    return matches
Get a list of runs with this episode name (episodes like episode_name#0).
Parameters
episode_name : str the name of the episode
Returns
runs_list : List a list of string run names
class Searches(SavedRuns):
    """A class that manages operations with search records."""

    def save_search(
        self,
        search_name: str,
        parameters: Dict,
        n_trials: int,
        best_params: Dict,
        best_value: float,
        metric: str,
        search_space: Dict,
    ) -> None:
        """Save a new search record.

        Parameters
        ----------
        search_name : str
            the name of the search to save
        parameters : dict
            the task parameters to save
        n_trials : int
            the number of trials in the search
        best_params : dict
            the best parameters dictionary
        best_value : float
            the best value
        metric : str
            the name of the objective metric
        search_space : dict
            a dictionary representing the search space; of this general structure:
            {'group/param_name': ('float/int/float_log/int_log', start, end),
            'group/param_name': ('categorical', [choices])}, e.g.
            {'data/overlap': ('int', 5, 100), 'training/lr': ('float_log', 1e-4, 1e-2),
            'data/feature_extraction': ('categorical', ['kinematic', 'bones'])}

        """
        pars = deepcopy(parameters)
        pars["results"] = {"best_value": best_value, "best_params": best_params}
        pars["meta"] = {
            "objective": metric,
            "n_trials": n_trials,
            "search_space": search_space,
        }
        self.save_episode(search_name, pars, {})

    def get_best_params_raw(self, search_name: str) -> Dict:
        """Get the raw dictionary of best parameters found by a search.

        Parameters
        ----------
        search_name : str
            the name of the search

        Returns
        -------
        best_params : dict
            a dictionary of the best parameters where the keys are in '{group}/{name}' format

        """
        return self.data.loc[search_name]["results"]["best_params"]

    def get_best_params(
        self,
        search_name: str,
        load_parameters: List = None,
        round_to_binary: List = None,
    ) -> Tuple:
        """Get the best parameters from a search.

        Parameters
        ----------
        search_name : str
            the name of the search
        load_parameters : List, optional
            a list of string names of the parameters to load (if not provided all parameters are loaded)
        round_to_binary : List, optional
            a list of string names of the loaded parameters that should be rounded to the nearest power of two

        Returns
        -------
        best_params : dict
            a nested dictionary (a defaultdict) of the best parameters,
            ``{group: {name: value}}`` or ``{group: {subgroup: {name: value}}}``
        model_name
            the `general/model_name` value recorded for the search

        Raises
        ------
        TypeError
            if a parameter in `round_to_binary` is not numeric

        """
        if round_to_binary is None:
            round_to_binary = []
        # copy: the rounding below must not modify the dictionary stored in self.data
        params = dict(self.data.loc[search_name]["results"]["best_params"])
        if load_parameters is not None:
            params = {k: v for k, v in params.items() if k in load_parameters}
        for par_name in round_to_binary:
            if par_name not in params:
                continue
            value = params[par_name]
            if not isinstance(value, (int, float, np.integer, np.floating)):
                raise TypeError(
                    f"Cannot round {par_name} parameter of type {type(value)} to a power of two"
                )
            i = 1
            while 2**i < value:
                i += 1
            # pick the closer of 2**(i-1) and 2**i; ties round up
            if value - (2 ** (i - 1)) < (2**i) - value:
                params[par_name] = 2 ** (i - 1)
            else:
                params[par_name] = 2**i
        res = defaultdict(lambda: defaultdict(lambda: {}))
        for k, v in params.items():
            big_key, small_key = k.split("/")[0], "/".join(k.split("/")[1:])
            if len(small_key.split("/")) == 1:
                res[big_key][small_key] = v
            else:
                group, key = small_key.split("/")
                res[big_key][group][key] = v
        model = self.data.loc[search_name]["general"]["model_name"]
        return res, model
A class that manages operations with search records.
def save_search(
    self,
    search_name: str,
    parameters: Dict,
    n_trials: int,
    best_params: Dict,
    best_value: float,
    metric: str,
    search_space: Dict,
) -> None:
    """Store the outcome of a hyperparameter search as a new record.

    The task parameters are deep-copied, so the caller's dictionary is left untouched.
    The copy gains a ``results`` group (best value and best parameters) and a ``meta``
    group (objective, number of trials, search space). It is then handed to
    ``save_episode`` with an empty dictionary as the behaviors argument.

    Parameters
    ----------
    search_name : str
        the name of the search to save
    parameters : dict
        the task parameters to save
    n_trials : int
        the number of trials in the search
    best_params : dict
        the best parameters dictionary
    best_value : float
        the best value of the objective
    metric : str
        the name of the objective metric
    search_space : dict
        a dictionary representing the search space; of this general structure:
        {'group/param_name': ('float/int/float_log/int_log', start, end),
        'group/param_name': ('categorical', [choices])}, e.g.
        {'data/overlap': ('int', 5, 100), 'training/lr': ('float_log', 1e-4, 1e-2),
        'data/feature_extraction': ('categorical', ['kinematic', 'bones'])}

    """
    search_results = dict(best_value=best_value, best_params=best_params)
    search_meta = dict(
        objective=metric,
        n_trials=n_trials,
        search_space=search_space,
    )
    record = deepcopy(parameters)
    record["results"] = search_results
    record["meta"] = search_meta
    self.save_episode(search_name, record, {})
Save a new search record.
Parameters
search_name : str the name of the search to save parameters : dict the task parameters to save n_trials : int the number of trials in the search best_params : dict the best parameters dictionary best_value : float the best value metric : str the name of the objective metric search_space : dict a dictionary representing the search space; of this general structure: {'group/param_name': ('float/int/float_log/int_log', start, end), 'group/param_name': ('categorical', [choices])}, e.g. {'data/overlap': ('int', 5, 100), 'training/lr': ('float_log', 1e-4, 1e-2), 'data/feature_extraction': ('categorical', ['kinematic', 'bones'])}
def get_best_params_raw(self, search_name: str) -> Dict:
    """Return the best-parameter dictionary of a search exactly as it is stored.

    Parameters
    ----------
    search_name : str
        the name of the search

    Returns
    -------
    best_params : dict
        a dictionary of the best parameters where the keys are in '{group}/{name}' format

    """
    search_record = self.data.loc[search_name]
    search_results = search_record["results"]
    return search_results["best_params"]
Get the raw dictionary of best parameters found by a search.
Parameters
search_name : str the name of the search
Returns
best_params : dict a dictionary of the best parameters where the keys are in '{group}/{name}' format
def get_best_params(
    self,
    search_name: str,
    load_parameters: List = None,
    round_to_binary: List = None,
) -> Tuple[Dict, str]:
    """Get the best parameters from a search.

    Parameters
    ----------
    search_name : str
        the name of the search
    load_parameters : List, optional
        a list of string names of the parameters to load (if not provided all parameters are loaded)
    round_to_binary : List, optional
        a list of string names of the loaded parameters that should be rounded to the nearest power of two

    Returns
    -------
    best_params : dict
        a nested dictionary of the best parameters: a '{group}/{name}' key becomes
        ``best_params[group][name]`` and a '{group}/{subgroup}/{name}' key becomes
        ``best_params[group][subgroup][name]``
    model : str
        the value of the ``general/model_name`` parameter of the search record

    Raises
    ------
    TypeError
        if a parameter listed in ``round_to_binary`` is not an int or a float

    """
    if round_to_binary is None:
        round_to_binary = []
    record = self.data.loc[search_name]
    # Copy so that rounding below never overwrites the dictionary stored in self.data
    # (the dataframe cell holds the dict by reference).
    params = dict(record["results"]["best_params"])
    if load_parameters is not None:
        params = {k: v for k, v in params.items() if k in load_parameters}
    for par_name in round_to_binary:
        if par_name not in params:
            continue
        value = params[par_name]
        if not isinstance(value, (float, int)):
            raise TypeError(
                f"Cannot round {par_name} parameter of type {type(value)} to a power of two"
            )
        # Find the smallest power 2**exponent (exponent >= 1) that is not below the value,
        # then pick the closer of it and the previous power; ties go to the larger one.
        exponent = 1
        while 2**exponent < value:
            exponent += 1
        lower, upper = 2 ** (exponent - 1), 2**exponent
        params[par_name] = lower if value - lower < upper - value else upper
    res = defaultdict(lambda: defaultdict(lambda: {}))
    for full_key, val in params.items():
        big_key, _, small_key = full_key.partition("/")
        if "/" not in small_key:
            res[big_key][small_key] = val
        else:
            group, key = small_key.split("/")
            res[big_key][group][key] = val
    model = record["general"]["model_name"]
    return res, model
Get the best parameters from a search.
Parameters
search_name : str the name of the search load_parameters : List, optional a list of string names of the parameters to load (if not provided all parameters are loaded) round_to_binary : List, optional a list of string names of the loaded parameters that should be rounded to the nearest power of two
Returns
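best_params : dict a nested dictionary of the best parameters (a '{group}/{name}' key becomes best_params[group][name]) model : str the value of the general/model_name parameter of the search record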
class Suggestions(SavedRuns):
    """Manage operations on suggestion records."""

    def save_suggestion(self, episode_name: str, parameters: Dict, meta_parameters):
        """Store a new suggestion record.

        The parameters are deep-copied, so the caller's dictionary is not modified.
        ``meta_parameters`` is stored under the ``meta`` key of the copy, which is then
        saved through the parent class's ``save_episode`` with ``behaviors_dict=None``.

        Parameters
        ----------
        episode_name : str
            the name under which the suggestion record is saved
        parameters : dict
            the task parameters to save
        meta_parameters
            the value stored under the ``meta`` key (presumably a dict; confirm with callers)

        """
        record = deepcopy(parameters)
        record["meta"] = meta_parameters
        super().save_episode(episode_name, record, behaviors_dict=None)
A class that manages operations with suggestion records.
def save_suggestion(self, episode_name: str, parameters: Dict, meta_parameters):
    """Store a new suggestion record.

    The parameters are deep-copied, so the caller's dictionary is not modified.
    ``meta_parameters`` is stored under the ``meta`` key of the copy, which is then
    saved through the parent class's ``save_episode`` with ``behaviors_dict=None``.

    Parameters
    ----------
    episode_name : str
        the name under which the suggestion record is saved
    parameters : dict
        the task parameters to save
    meta_parameters
        the value stored under the ``meta`` key (presumably a dict; confirm with callers)

    """
    record = deepcopy(parameters)
    record["meta"] = meta_parameters
    super().save_episode(episode_name, record, behaviors_dict=None)
Save a new suggestion record.
class SavedStores:
    """A class that manages operations with saved dataset records."""

    def __init__(self, path):
        """Initialize the class.

        Parameters
        ----------
        path : str
            the path to the pickled dataframe of saved dataset records

        """
        self.path = path
        self.data = pd.read_pickle(path)
        # parameters that are ignored when matching a request against stored records
        self.skip_keys = [
            "feature_save_path",
            "saved_data_path",
            "real_lens",
            "recompute_annotation",
        ]

    def clear(self) -> None:
        """Remove all datasets (the dataframe is saved after each removal)."""
        for dataset_name in list(self.data.index):
            self.remove_dataset(dataset_name)

    def dataset_names(self) -> List:
        """Get a list of dataset names.

        Returns
        -------
        dataset_names : List
            a list of string dataset names

        """
        return list(self.data.index)

    def remove(self, names: List) -> None:
        """Remove some datasets.

        Names that are not present in the records are ignored.

        Parameters
        ----------
        names : List
            a list of string names of the datasets to delete

        """
        for dataset_name in names:
            if dataset_name in self.data.index:
                self.remove_dataset(dataset_name)

    def remove_dataset(self, dataset_name: str) -> None:
        """Remove a dataset record and save the dataframe (no-op if the name is unknown).

        Parameters
        ----------
        dataset_name : str
            the name of the dataset to remove

        """
        if dataset_name in self.data.index:
            self.data = self.data.drop(index=dataset_name)
            self._save()

    def find_name(self, parameters: Dict) -> Union[str, None]:
        """Find a record that satisfies the parameters (if it exists).

        Parameters whose value is None and parameters listed in ``skip_keys`` are
        ignored. A record matches when it has equal values for all remaining
        parameters and a null value in every other column not listed in ``skip_keys``.

        Parameters
        ----------
        parameters : dict
            a dictionary of data parameters

        Returns
        -------
        name : str or None
            the name of a record that has the same parameters (None if it does not exist;
            the first one in the dataframe's row order if there are several)

        """
        query = {}
        for key, value in parameters.items():
            if value is None or key in self.skip_keys:
                continue
            if key not in self.data.columns:
                # no stored record can have a value for a parameter that has no column
                return None
            query[key] = value
        candidates = self.data[
            (self.data[list(query)] == pd.Series(query)).all(axis=1)
        ]
        for name, row in candidates.iterrows():
            is_match = True
            for key in candidates.columns:
                if key in self.skip_keys or key in query:
                    continue
                isnull = pd.isnull(row[key])
                # array-like cells make pd.isnull return an array; such cells count as set
                if not isinstance(isnull, bool) or not isnull:
                    is_match = False
                    break
            if is_match:
                return name
        return None

    def save_store(self, episode_name: str, parameters: Dict) -> None:
        """Save a new saved dataset record.

        Columns for previously unseen parameters are added (filled with NaN). The record
        is only written when no existing record matches the parameters (see ``find_name``).

        Parameters
        ----------
        episode_name : str
            the name of the dataset
        parameters : dict
            a dictionary of data parameters

        """
        pars = deepcopy(parameters)
        for key in parameters:
            if key not in self.data.columns:
                self.data[key] = np.nan
        if self.find_name(pars) is None:
            self.data.loc[episode_name] = pars
            self._save()

    def _save(self):
        """Save the dataframe."""
        self.data.to_pickle(self.path)

    def check_name_validity(self, store_name: str) -> bool:
        """Check whether a store name is still available.

        Parameters
        ----------
        store_name : str
            the name to check

        Returns
        -------
        result : bool
            True if the name can be used (no record with this name exists)

        """
        return store_name not in self.data.index
A class that manages operations with saved dataset records.
def __init__(self, path):
    """Load the saved dataset records from disk.

    Parameters
    ----------
    path : str
        the path to the pickled dataframe of saved dataset records

    """
    self.path = path
    self.data = pd.read_pickle(self.path)
    # parameter names that are ignored when matching a request against stored records
    ignored = ["feature_save_path", "saved_data_path"]
    ignored += ["real_lens", "recompute_annotation"]
    self.skip_keys = ignored
Initialize the class.
Parameters
path : str the path to the pickled dataframe of saved dataset records
def clear(self) -> None:
    """Delete every dataset record, one ``remove_dataset`` call per name."""
    existing_names = list(self.data.index)
    for name in existing_names:
        self.remove_dataset(name)
Remove all datasets.
def dataset_names(self) -> List:
    """Return the names of all stored dataset records.

    Returns
    -------
    dataset_names : List
        a list of string dataset names, in the dataframe's row order

    """
    names = [name for name in self.data.index]
    return names
Get a list of dataset names.
Returns
dataset_names : List a list of string dataset names
def remove(self, names: List) -> None:
    """Delete the listed dataset records, silently skipping unknown names.

    Parameters
    ----------
    names : List
        a list of string names of the datasets to delete

    """
    for name in names:
        # membership is re-checked against the current dataframe on every iteration
        if name not in self.data.index:
            continue
        self.remove_dataset(name)
Remove some datasets.
Parameters
names : List a list of string names of the datasets to delete
def remove_dataset(self, dataset_name: str) -> None:
    """Drop a single dataset record and persist the dataframe.

    Nothing happens (and nothing is saved) when the name is not in the records.

    Parameters
    ----------
    dataset_name : str
        the name of the dataset to remove

    """
    if dataset_name not in self.data.index:
        return
    self.data = self.data.drop(index=dataset_name)
    self._save()
Remove a dataset record.
Parameters
dataset_name : str the name of the dataset to remove
def find_name(self, parameters: Dict) -> Union[str, None]:
    """Find a record that satisfies the parameters (if it exists).

    Parameters whose value is None and parameters listed in ``self.skip_keys`` are
    ignored. A record matches when it has equal values for all remaining parameters
    and a null value in every other column not listed in ``self.skip_keys``.

    Parameters
    ----------
    parameters : dict
        a dictionary of data parameters

    Returns
    -------
    name : str or None
        the name of a record that has the same parameters (None if it does not exist;
        the first one in the dataframe's row order if there are several)

    """
    query = {}
    for key, value in parameters.items():
        if value is None or key in self.skip_keys:
            continue
        if key not in self.data.columns:
            # no stored record can have a value for a parameter that has no column
            return None
        query[key] = value
    candidates = self.data[
        (self.data[list(query)] == pd.Series(query)).all(axis=1)
    ]
    for name, row in candidates.iterrows():
        is_match = True
        for key in candidates.columns:
            if key in self.skip_keys or key in query:
                continue
            isnull = pd.isnull(row[key])
            # array-like cells make pd.isnull return an array; such cells count as set
            if not isinstance(isnull, bool) or not isnull:
                is_match = False
                break
        if is_match:
            return name
    return None
Find a record that satisfies the parameters (if it exists).
Parameters
parameters : dict a dictionary of data parameters
Returns
name : str the name of a record that has the same parameters (None if it does not exist; the earliest if there are several)
def save_store(self, episode_name: str, parameters: Dict) -> None:
    """Record a new saved dataset unless an equivalent record already exists.

    Columns for previously unseen parameters are added first (filled with NaN). A deep
    copy of the parameters is then stored under ``episode_name`` and the dataframe is
    saved, but only when ``find_name`` reports no matching record.

    Parameters
    ----------
    episode_name : str
        the name of the dataset
    parameters : dict
        a dictionary of data parameters

    """
    record = deepcopy(parameters)
    new_columns = [key for key in parameters if key not in self.data.columns]
    for column in new_columns:
        self.data[column] = np.nan
    if self.find_name(record) is not None:
        return
    self.data.loc[episode_name] = record
    self._save()
Save a new saved dataset record.
Parameters
episode_name : str the name of the dataset parameters : dict a dictionary of data parameters
def check_name_validity(self, store_name: str) -> bool:
    """Tell whether a store name is still free to use.

    Parameters
    ----------
    store_name : str
        the name to check

    Returns
    -------
    result : bool
        True if no record with this name exists yet, False otherwise

    """
    return store_name not in self.data.index
Check whether a store name is still available (i.e. not already used by an existing record).
Parameters
store_name : str the name to check
Returns
result : bool True if the name can be used