dlc2action.project.meta

Handling meta (history) files

#
# Copyright 2020-2022 by A. Mathis Group and contributors. All rights reserved.
#
# This project and all its files are licensed under GNU AGPLv3 or later version. A copy is included in dlc2action/LICENSE.AGPL.
#
"""
Handling meta (history) files
"""

import os
from typing import Dict, List, Tuple, Union
import pandas as pd
from time import localtime, strftime
from copy import deepcopy
import numpy as np
from collections import defaultdict
import warnings
from dlc2action.utils import correct_path


class Run:
    """
    A class that manages operations with a single episode record
    """

    def __init__(
        self,
        episode_name: str,
        project_path: str,
        meta_path: str = None,
        params: Dict = None,
    ):
        """
        Parameters
        ----------
        episode_name : str
            the name of the episode
        project_path : str
            the current project folder path
        meta_path : str, optional
            the path to the pickled SavedRuns dataframe
        params : dict, optional
            alternative to meta_path: pre-loaded pandas Series of episode parameters
        """

        self.name = episode_name
        self.project_path = project_path
        if meta_path is not None:
            try:
                self.params = pd.read_pickle(meta_path).loc[episode_name]
            except KeyError:
                raise ValueError(f"The {episode_name} episode does not exist!")
        elif params is not None:
            self.params = params
        else:
            raise ValueError("Either meta_path or params has to be not None")

    def training_time(self) -> Union[int, float]:
        """
        Get the training time in seconds

        Returns
        -------
        training_time : int
            the training time in seconds (NaN if it was not recorded)
        """

        time_str = self.params["meta"].get("training_time")
        try:
            if time_str is None or np.isnan(time_str):
                return np.nan
        except TypeError:
            # np.isnan raises a TypeError on strings, which means
            # a valid "%H:%M:%S" record was found
            pass
        h, m, s = time_str.split(":")
        seconds = int(h) * 3600 + int(m) * 60 + int(s)
        return seconds

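    # For example, a "training_time" record of "01:30:00" yields
    # 1 * 3600 + 30 * 60 + 0 = 5400 seconds.
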
    def model_file(self, load_epoch: int = None) -> str:
        """
        Get a checkpoint file path

        Parameters
        ----------
        load_epoch : int, optional
            the epoch to load (the closest checkpoint will be chosen; if not given will be set to last)

        Returns
        -------
        checkpoint_path : str
            the path to the checkpoint (None if no checkpoints are saved)
        """

        model_path = correct_path(
            self.params["training"]["model_save_path"], self.project_path
        )
        model_files = os.listdir(model_path)
        if len(model_files) == 0:
            return None
        if load_epoch is None:
            model_file = sorted(model_files)[-1]
        else:
            # checkpoint files are assumed to be named "epoch{N}.*",
            # so the epoch index starts at character 5
            epochs = [int(file[5:].split(".")[0]) for file in model_files]
            diffs = [np.abs(epoch - load_epoch) for epoch in epochs]
            argmin = np.argmin(diffs)
            model_file = model_files[argmin]
        return os.path.join(model_path, model_file)

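    # A short selection sketch (hypothetical file names): with checkpoints
    # ["epoch30.pt", "epoch50.pt"], model_file(load_epoch=48) picks
    # "epoch50.pt", the checkpoint whose epoch index is closest to 48.
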
    def dataset_name(self) -> str:
        """
        Get the dataset name

        Returns
        -------
        dataset_name : str
            the name of the dataset record
        """

        data_path = correct_path(
            self.params["data"]["feature_save_path"], self.project_path
        )
        dataset_name = os.path.basename(data_path)
        return dataset_name

    def split_file(self) -> str:
        """
        Get the split file

        Returns
        -------
        split_path : str
            the path to the split file
        """

        return correct_path(self.params["training"]["split_path"], self.project_path)

    def log_file(self) -> str:
        """
        Get the log file

        Returns
        -------
        log_path : str
            the path to the log file
        """

        return correct_path(self.params["training"]["log_file"], self.project_path)

    def split_info(self) -> Dict:
        """
        Get the train/test/val split information

        Returns
        -------
        split_info : dict
            a dictionary with [val_frac, test_frac, partition_method] keys and corresponding values
        """

        val_frac = self.params["training"]["val_frac"]
        test_frac = self.params["training"]["test_frac"]
        partition_method = self.params["training"]["partition_method"]
        return {
            "val_frac": val_frac,
            "test_frac": test_frac,
            "partition_method": partition_method,
        }

    def same_split_info(self, split_info: Dict) -> bool:
        """
        Check whether this episode has the same split information

        Parameters
        ----------
        split_info : dict
            a dictionary with [val_frac, test_frac, partition_method] keys and corresponding values from another episode

        Returns
        -------
        result : bool
            if True, this episode has the same split information
        """

        self_split_info = self.split_info()
        for k in ["val_frac", "test_frac", "partition_method"]:
            if self_split_info[k] != split_info[k]:
                return False
        return True

    def get_metric_log(self, mode: str, metric_name: str) -> np.ndarray:
        """
        Get the metric log

        Parameters
        ----------
        mode : {'train', 'val'}
            the mode to get the log from
        metric_name : str
            the metric to get the log for (has to be one of the metrics computed for this episode during training)

        Returns
        -------
        log : np.ndarray
            the log of metric values (empty if the metric was not computed during training)
        """

        metric_array = []
        with open(self.log_file()) as f:
            for line in f.readlines():
                if mode == "train" and line.startswith("[epoch"):
                    line = line.split("]: ")[1]
                elif mode == "val" and line.startswith("validation"):
                    line = line.split("validation: ")[1]
                else:
                    continue
                metrics = line.split(", ")
                for metric in metrics:
                    name, value = metric.split()
                    if name == metric_name:
                        metric_array.append(float(value))
        return np.array(metric_array)

    def get_epoch_list(self, mode: str) -> List:
        """
        Get a list of epoch indices

        Parameters
        ----------
        mode : {'train', 'val'}
            the mode to get the epochs for

        Returns
        -------
        epoch_list : list
            a list of int epoch indices
        """

        epoch_list = []
        with open(self.log_file()) as f:
            for line in f.readlines():
                if line.startswith("[epoch"):
                    epoch = int(line[7:].split("]:")[0])
                    if mode == "train":
                        epoch_list.append(epoch)
                elif mode == "val" and line.startswith("validation"):
                    # validation lines follow the corresponding "[epoch ...]" line,
                    # so the last parsed epoch index applies to them
                    epoch_list.append(epoch)
        return epoch_list

    def get_metrics(self) -> List:
        """
        Get a list of metric names in the episode log

        Returns
        -------
        metrics : List
            a list of string metric names
        """

        metrics = []
        with open(self.log_file()) as f:
            for line in f.readlines():
                if line.startswith("[epoch"):
                    line = line.split("]: ")[1]
                elif line.startswith("validation"):
                    line = line.split("validation: ")[1]
                else:
                    continue
                metric_logs = line.split(", ")
                for metric in metric_logs:
                    name, _ = metric.split()
                    metrics.append(name)
                break
        return metrics

    def unfinished(self) -> bool:
        """
        Check whether this episode was interrupted

        Returns
        -------
        result : bool
            True if the number of epochs in the log file is smaller than in the parameters
        """

        num_epoch_theor = self.params["training"]["num_epochs"]
        log_file = self.log_file()
        if not isinstance(log_file, str):
            return False
        if not os.path.exists(log_file):
            return True
        with open(log_file) as f:
            num_epoch = 0
            val = False
            for line in f.readlines():
                num_epoch += 1
                if num_epoch == 2 and line.startswith("validation"):
                    val = True
            # if validation was logged, every epoch produces two lines,
            # so the line count has to be halved
            if val:
                num_epoch //= 2
        return num_epoch < num_epoch_theor

    def get_class_ind(self, class_name: str) -> int:
        """
        Get the integer label from a class name

        Parameters
        ----------
        class_name : str
            the name of the class

        Returns
        -------
        class_ind : int
            the integer label
        """

        behaviors_dict = self.params["meta"]["behaviors_dict"]
        for k, v in behaviors_dict.items():
            if v == class_name:
                return k
        raise ValueError(
            f"The {class_name} class is not in classes predicted by {self.name} ({behaviors_dict})"
        )

    def get_behaviors_dict(self) -> Dict:
        """
        Get the behaviors dictionary of the episode

        Returns
        -------
        behaviors_dict : dict
            a dictionary with class indices as keys and labels as values
        """

        return self.params["meta"]["behaviors_dict"]

    def get_num_classes(self) -> int:
        """
        Get the number of classes in the episode

        Returns
        -------
        num_classes : int
            the number of classes
        """

        return len(self.params["meta"]["behaviors_dict"])


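# A minimal usage sketch (the episode name and paths are hypothetical, not part
# of the library): load a saved episode record and inspect its history.
#
#     run = Run(
#         "episode_1",
#         project_path="/path/to/project",
#         meta_path="/path/to/project/meta/episodes.pickle",
#     )
#     run.training_time()               # e.g. 5400
#     run.get_metrics()                 # e.g. ["accuracy", "f1"]
#     run.get_metric_log("val", "f1")   # np.ndarray of per-epoch values

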
class DecisionThresholds:
    """
    A class that saves and looks up tuned decision thresholds
    """

    def __init__(self, path: str) -> None:
        """
        Parameters
        ----------
        path : str
            the path to the pickled thresholds dataframe
        """

        self.path = path
        self.data = pd.read_pickle(path)

    def save_thresholds(
        self,
        episode_names: List,
        epochs: List,
        metric_name: str,
        metric_parameters: Dict,
        thresholds: List,
    ) -> None:
        """
        Add a new record

        Parameters
        ----------
        episode_names : list
            the names of the episodes
        epochs : list
            the epoch indices
        metric_name : str
            the name of the metric the thresholds were tuned on
        metric_parameters : dict
            the metric parameter dictionary
        thresholds : list
            a list of float decision thresholds
        """

        episodes = set(zip(episode_names, epochs))
        # work on a copy so that the caller's dictionary is not modified
        metric_parameters = deepcopy(metric_parameters)
        for key in ["average", "threshold_value", "ignored_classes"]:
            if key in metric_parameters:
                metric_parameters.pop(key)
        parameters = {(metric_name, k): v for k, v in metric_parameters.items()}
        parameters["thresholds"] = thresholds
        parameters["episodes"] = episodes
        pars = {k: [v] for k, v in parameters.items()}
        self.data = pd.concat([self.data, pd.DataFrame.from_dict(pars)], axis=0)
        self._save()

    def find_thresholds(
        self,
        episode_names: List,
        epochs: List,
        metric_name: str,
        metric_parameters: Dict,
    ) -> Union[List, None]:
        """
        Find a record

        Parameters
        ----------
        episode_names : list
            the names of the episodes
        epochs : list
            the epoch indices
        metric_name : str
            the name of the metric the thresholds were tuned on
        metric_parameters : dict
            the metric parameter dictionary

        Returns
        -------
        thresholds : list
            a list of float decision thresholds (None if no record is found)
        """

        episodes = set(zip(episode_names, epochs))
        # work on a copy so that the caller's dictionary is not modified
        metric_parameters = deepcopy(metric_parameters)
        for key in ["average", "threshold_value", "ignored_classes"]:
            if key in metric_parameters:
                metric_parameters.pop(key)
        parameters = {(metric_name, k): v for k, v in metric_parameters.items()}
        parameters["episodes"] = episodes
        filter_dict = deepcopy(parameters)
        for key, value in parameters.items():
            if value is None:
                filter_dict.pop(key)
            elif key not in self.data.columns:
                return None
        data = self.data[
            (self.data[list(filter_dict)] == pd.Series(filter_dict)).all(axis=1)
        ]
        if len(data) > 0:
            thresholds = data.iloc[0]["thresholds"]
            return thresholds
        else:
            return None

    def _save(self) -> None:
        """
        Save the records
        """

        self.data.copy().to_pickle(self.path)


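# A usage sketch (hypothetical names and values): store thresholds tuned for two
# runs of an episode and look them up again with the same parameters.
#
#     store = DecisionThresholds("/path/to/project/meta/thresholds.pickle")
#     store.save_thresholds(
#         episode_names=["episode_1::0", "episode_1::1"],
#         epochs=[50, 50],
#         metric_name="f1",
#         metric_parameters={"average": "macro"},
#         thresholds=[0.4, 0.55, 0.5],
#     )
#     store.find_thresholds(
#         ["episode_1::0", "episode_1::1"], [50, 50], "f1", {"average": "macro"}
#     )  # -> [0.4, 0.55, 0.5]

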
class SavedRuns:
    """
    A class that manages operations with all episode (or prediction) records
    """

    def __init__(self, path: str, project_path: str) -> None:
        """
        Parameters
        ----------
        path : str
            the path to the pickled SavedRuns dataframe
        project_path : str
            the current project folder path
        """

        self.path = path
        self.project_path = project_path
        self.data = pd.read_pickle(path)

    def update(
        self,
        data: pd.DataFrame,
        data_path: str,
        annotation_path: str,
        name_map: Dict = None,
        force: bool = False,
    ) -> None:
        """
        Update with new data

        Parameters
        ----------
        data : pd.DataFrame
            the new dataframe
        data_path : str
            the new data path
        annotation_path : str
            the new annotation path
        name_map : dict, optional
            the name change dictionary; keys are old episode names and values are new episode names
        force : bool, default False
            replace existing episodes if `True`
        """

        if name_map is None:
            name_map = {}
        data = data.rename(index=name_map)
        for episode in data.index:
            # repoint the model, log and split paths to the current project folder
            new_model = os.path.join(self.project_path, "results", "model", episode)
            data.loc[episode, ("training", "model_save_path")] = new_model
            new_log = os.path.join(
                self.project_path, "results", "logs", f"{episode}.txt"
            )
            data.loc[episode, ("training", "log_file")] = new_log
            old_split = data.loc[episode, ("training", "split_path")]
            if old_split is None:
                new_split = None
            else:
                new_split = os.path.join(
                    self.project_path, "results", "splits", os.path.basename(old_split)
                )
            data.loc[episode, ("training", "split_path")] = new_split
            data.loc[episode, ("data", "data_path")] = data_path
            data.loc[episode, ("data", "annotation_path")] = annotation_path
            if episode in self.data.index:
                if force:
                    self.data = self.data.drop(index=[episode])
                else:
                    raise RuntimeError(f"The {episode} episode name is already taken!")
        self.data = pd.concat([self.data, data])
        self._save()

    def get_subset(self, episode_names: List) -> pd.DataFrame:
        """
        Get a subset of the raw metadata

        Parameters
        ----------
        episode_names : list
            a list of the episodes to include

        Returns
        -------
        subset : pd.DataFrame
            the subset of the records dataframe
        """

        for episode in episode_names:
            if episode not in self.data.index:
                raise ValueError(
                    f"The {episode} episode is not in the records; please run `Project.list_episodes()` to explore the records"
                )
        return self.data.loc[episode_names]

    def get_saved_data_path(self, episode_name: str) -> str:
        """
        Get the `saved_data_path` parameter for the episode

        Parameters
        ----------
        episode_name : str
            the name of the episode

        Returns
        -------
        saved_data_path : str
            the saved data path
        """

        return self.data.loc[episode_name]["data"]["saved_data_path"]

    def check_name_validity(self, episode_name: str) -> bool:
        """
        Check if an episode name already exists

        Parameters
        ----------
        episode_name : str
            the name to check

        Returns
        -------
        result : bool
            True if the name can be used
        """

        return episode_name not in self.data.index

    def update_episode_metrics(self, episode_name: str, metrics: Dict) -> None:
        """
        Update the metadata with evaluation results

        Parameters
        ----------
        episode_name : str
            the name of the episode to update
        metrics : dict
            a dictionary of the metrics
        """

        for key, value in metrics.items():
            self.data.loc[episode_name, ("results", key)] = value
        self._save()

    def save_episode(
        self,
        episode_name: str,
        parameters: Dict,
        behaviors_dict: Dict,
        suppress_validation: bool = False,
        training_time: str = None,
    ) -> None:
        """
        Save a new run record

        Parameters
        ----------
        episode_name : str
            the name of the episode
        parameters : dict
            the parameters to save
        behaviors_dict : dict
            the dictionary of behaviors (keys are indices, values are names)
        suppress_validation : bool, default False
            if True, an existing episode with the same name will be overwritten
        training_time : str, optional
            the training time in '%H:%M:%S' format
        """

        if not suppress_validation and episode_name in self.data.index:
            raise ValueError(f"Episode {episode_name} already exists!")
        pars = deepcopy(parameters)
        if "meta" not in pars:
            pars["meta"] = {
                "time": strftime("%Y-%m-%d %H:%M:%S", localtime()),
                "behaviors_dict": behaviors_dict,
            }
        else:
            pars["meta"]["time"] = strftime("%Y-%m-%d %H:%M:%S", localtime())
            pars["meta"]["behaviors_dict"] = behaviors_dict
        if training_time is not None:
            pars["meta"]["training_time"] = training_time
        if len(parameters.keys()) > 1:
            # keep only the active loss, metric and SSL parameter groups
            pars["losses"] = pars["losses"].get(pars["general"]["loss_function"], {})
            for metric_name in pars["general"]["metric_functions"]:
                pars[metric_name] = pars["metrics"].get(metric_name, {})
            if pars["general"].get("ssl", None) is not None:
                for ssl_name in pars["general"]["ssl"]:
                    pars[ssl_name] = pars["ssl"].get(ssl_name, {})
            for group_name in ["metrics", "ssl"]:
                if group_name in pars:
                    pars.pop(group_name)
        # flatten the nested parameter dictionary into (group, name) column keys
        data = {
            (big_key, small_key): value
            for big_key, big_value in pars.items()
            for small_key, value in big_value.items()
        }
        list_keys = []
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", message="DataFrame is highly fragmented")
            for k, v in data.items():
                if k not in self.data.columns:
                    self.data[k] = np.nan
                if isinstance(v, list):
                    list_keys.append(k)
            # columns that hold lists need the object dtype
            for k in list_keys:
                self.data[k] = self.data[k].astype(object)
            self.data.loc[episode_name] = data
        self._save()

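    # A sketch of the flattening above (illustrative values): a nested dictionary
    # such as
    #
    #     {"training": {"lr": 1e-3, "num_epochs": 100}}
    #
    # is stored as one row with MultiIndex columns
    # ("training", "lr") and ("training", "num_epochs").
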
    def load_parameters(self, episode_name: str) -> Dict:
        """
        Load the task parameters from a record

        Parameters
        ----------
        episode_name : str
            the name of the episode to load

        Returns
        -------
        parameters : dict
            the loaded task parameters
        """

        parameters = defaultdict(lambda: defaultdict(lambda: {}))
        episode = self.data.loc[episode_name].dropna().to_dict()
        keys = ["data", "augmentations", "general", "training", "model", "features"]
        for key in episode:
            big_key, small_key = key
            if big_key in keys:
                parameters[big_key][small_key] = episode[key]
        ssl_keys = parameters["general"].get("ssl", None)
        metric_keys = parameters["general"].get("metric_functions", None)
        loss_key = parameters["general"]["loss_function"]
        if ssl_keys is None:
            ssl_keys = []
        if metric_keys is None:
            metric_keys = []
        # nest the SSL, metric and loss parameters back into their groups
        for key in episode:
            big_key, small_key = key
            if big_key in ssl_keys:
                parameters["ssl"][big_key][small_key] = episode[key]
            elif big_key in metric_keys:
                parameters["metrics"][big_key][small_key] = episode[key]
            elif big_key == "losses":
                parameters["losses"][loss_key][small_key] = episode[key]
        parameters = {k: dict(v) for k, v in parameters.items()}
        parameters["general"]["num_classes"] = Run(
            episode_name, self.project_path, params=self.data.loc[episode_name]
        ).get_num_classes()
        return parameters

    def get_active_datasets(self) -> List:
        """
        Get a list of names of datasets that are used by unfinished episodes

        Returns
        -------
        active_datasets : list
            a list of dataset names used by unfinished episodes
        """

        active_datasets = []
        for episode_name in self.unfinished_episodes():
            run = Run(
                episode_name, self.project_path, params=self.data.loc[episode_name]
            )
            active_datasets.append(run.dataset_name())
        return active_datasets

    def list_episodes(
        self,
        episode_names: List = None,
        value_filter: str = "",
        display_parameters: List = None,
    ) -> pd.DataFrame:
        """
        Get a filtered pandas dataframe with episode metadata

        Parameters
        ----------
        episode_names : List
            a list of strings of episode names
        value_filter : str
            a string of filters to apply of this general structure:
            'group_name1/par_name1::(<>=)value1,group_name2/par_name2::(<>=)value2', e.g.
            'data/overlap::=50,results/recall::>0.5,data/feature_extraction::=kinematic'
        display_parameters : List
            list of parameters to display (e.g. ['data/overlap', 'results/recall'])

        Returns
        -------
        pandas.DataFrame
            the filtered dataframe
        """

        if episode_names is not None:
            data = deepcopy(self.data.loc[episode_names])
        else:
            data = deepcopy(self.data)
        if len(data) == 0:
            return pd.DataFrame()
        try:
            filters = value_filter.split(",")
            if filters == [""]:
                filters = []
            for f in filters:
                par_name, condition = f.split("::")
                group_name, par_name = par_name.split("/")
                sign, value = condition[0], condition[1:]
                if value[0] == "=":
                    sign += "="
                    value = value[1:]
                try:
                    value = float(value)
                except ValueError:
                    if value == "True":
                        value = True
                    elif value == "False":
                        value = False
                    elif value == "None":
                        value = None
                if value is None:
                    if sign == "=":
                        data = data[data[group_name][par_name].isna()]
                    elif sign == "!=":
                        data = data[~data[group_name][par_name].isna()]
                elif sign == ">":
                    data = data[data[group_name][par_name] > value]
                elif sign == ">=":
                    data = data[data[group_name][par_name] >= value]
                elif sign == "<":
                    data = data[data[group_name][par_name] < value]
                elif sign == "<=":
                    data = data[data[group_name][par_name] <= value]
                elif sign == "=":
                    data = data[data[group_name][par_name] == value]
                elif sign == "!=":
                    data = data[data[group_name][par_name] != value]
                else:
                    raise ValueError(
                        "Please use one of the signs: [>, <, >=, <=, =, !=]"
                    )
        except ValueError:
            raise ValueError(
                f"The {value_filter} filter is not valid, please use the following format:"
                " 'group1/parameter1::[sign][value],group2/parameter2::[sign][value]', "
                "e.g. 'training/num_epochs::>=200,model/num_f_maps::=128,meta/time::>2022-06-01'"
            )
        if display_parameters is not None:
            if isinstance(display_parameters[0], str):
                display_parameters = [
                    (x.split("/")[0], x.split("/")[1]) for x in display_parameters
                ]
            display_parameters = [x for x in display_parameters if x in data.columns]
            data = data[display_parameters]
        return data

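    # A filtering sketch (hypothetical records): keep episodes that trained for
    # at least 200 epochs and reached validation recall above 0.5, then show
    # only those two columns.
    #
    #     saved_runs.list_episodes(
    #         value_filter="training/num_epochs::>=200,results/recall::>0.5",
    #         display_parameters=["training/num_epochs", "results/recall"],
    #     )
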
    def rename_episode(self, episode_name: str, new_episode_name: str) -> None:
        """
        Rename an episode record

        Parameters
        ----------
        episode_name : str
            the current name of the episode
        new_episode_name : str
            the new name of the episode
        """

        if episode_name in self.data.index and new_episode_name not in self.data.index:
            self.data.loc[new_episode_name] = self.data.loc[episode_name]
            model_path = self.data.loc[new_episode_name, ("training", "model_save_path")]
            self.data.loc[new_episode_name, ("training", "model_save_path")] = os.path.join(
                os.path.dirname(model_path), new_episode_name
            )
            log_path = self.data.loc[new_episode_name, ("training", "log_file")]
            self.data.loc[new_episode_name, ("training", "log_file")] = os.path.join(
                os.path.dirname(log_path), f"{new_episode_name}.txt"
            )
            self.data = self.data.drop(index=episode_name)
            self._save()
        else:
            raise ValueError(
                f"Cannot rename {episode_name} to {new_episode_name}: the old name has "
                f"to exist in the records and the new name has to be unused"
            )

    def remove_episode(self, episode_name: str) -> None:
        """
        Remove the metafile record of an episode

        Parameters
        ----------
        episode_name : str
            the name of the episode to remove
        """

        if episode_name in self.data.index:
            self.data = self.data.drop(index=episode_name)
            self._save()

    def unfinished_episodes(self) -> List:
        """
        Get a list of unfinished episodes (currently running or interrupted)

        Returns
        -------
        interrupted_episodes : List
            a list of string names of unfinished episodes in the records
        """

        unfinished = []
        for name, params in self.data.iterrows():
            if Run(name, project_path=self.project_path, params=params).unfinished():
                unfinished.append(name)
        return unfinished

    def update_episode_results(
        self,
        episode_name: str,
        logs: Tuple,
        training_time: str = None,
    ) -> None:
        """
        Add results to an episode record

        Parameters
        ----------
        episode_name : str
            the name of the episode to update
        logs : tuple
            a log tuple from task.train()
        training_time : str
            the training time
        """

        metrics_log = logs[1]
        results = {}
        for key, value in metrics_log["val"].items():
            results[("results", key)] = value[-1]
        if training_time is not None:
            results[("meta", "training_time")] = training_time
        for k, v in results.items():
            self.data.loc[episode_name, k] = v
        self._save()

    def get_runs(self, episode_name: str) -> List:
        """
        Get a list of runs with this episode name (episodes like `episode_name::0`)

        Parameters
        ----------
        episode_name : str
            the name of the episode

        Returns
        -------
        runs_list : List
            a list of string run names
        """

        if episode_name is None:
            return []
        index = self.data.index
        runs_list = []
        for name in index:
            if name.startswith(episode_name):
                split = name.split("::")
                if split[0] == episode_name:
                    if (len(split) > 1 and split[-1].isnumeric()) or len(split) == 1:
                        runs_list.append(name)
                elif name == episode_name:
                    runs_list.append(name)
        return runs_list

    def _save(self):
        """
        Save the dataframe
        """

        self.data.copy().to_pickle(self.path)


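# Run records share the episode name with a "::{index}" suffix, so a lookup
# could behave like this (illustrative names):
#
#     saved_runs.get_runs("episode")
#     # -> ["episode", "episode::0", "episode::1"]

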
class Searches(SavedRuns):
    """
    A class that manages operations with search records
    """

    def save_search(
        self,
        search_name: str,
        parameters: Dict,
        n_trials: int,
        best_params: Dict,
        best_value: float,
        metric: str,
        search_space: Dict,
    ) -> None:
        """
        Save a new search record

        Parameters
        ----------
        search_name : str
            the name of the search to save
        parameters : dict
            the task parameters to save
        n_trials : int
            the number of trials in the search
        best_params : dict
            the best parameters dictionary
        best_value : float
            the best value
        metric : str
            the name of the objective metric
        search_space : dict
            a dictionary representing the search space; of this general structure:
            {'group/param_name': ('float/int/float_log/int_log', start, end),
            'group/param_name': ('categorical', [choices])}, e.g.
            {'data/overlap': ('int', 5, 100), 'training/lr': ('float_log', 1e-4, 1e-2),
            'data/feature_extraction': ('categorical', ['kinematic', 'bones'])}
        """

        pars = deepcopy(parameters)
        pars["results"] = {"best_value": best_value, "best_params": best_params}
        pars["meta"] = {
            "objective": metric,
            "n_trials": n_trials,
            "search_space": search_space,
        }
        self.save_episode(search_name, pars, {})

    def get_best_params_raw(self, search_name: str) -> Dict:
        """
        Get the raw dictionary of best parameters found by a search

        Parameters
        ----------
        search_name : str
            the name of the search

        Returns
        -------
        best_params : dict
            a dictionary of the best parameters where the keys are in '{group}/{name}' format
        """

        return self.data.loc[search_name]["results"]["best_params"]

    def get_best_params(
        self,
        search_name: str,
        load_parameters: List = None,
        round_to_binary: List = None,
    ) -> Tuple:
        """
        Get the best parameters from a search

        Parameters
        ----------
        search_name : str
            the name of the search
        load_parameters : List, optional
            a list of string names of the parameters to load (if not provided all parameters are loaded)
        round_to_binary : List, optional
            a list of string names of the loaded parameters that should be rounded to the nearest power of two

        Returns
        -------
        best_params : dict
            a dictionary of the best parameters
        model_name : str
            the name of the model used in the search
        """

        if round_to_binary is None:
            round_to_binary = []
        params = self.data.loc[search_name]["results"]["best_params"]
        if load_parameters is not None:
            params = {k: v for k, v in params.items() if k in load_parameters}
        for par_name in round_to_binary:
            if par_name not in params:
                continue
            if not isinstance(params[par_name], (int, float)):
                raise TypeError(
                    f"Cannot round {par_name} parameter of type {type(params[par_name])} to a power of two"
                )
            # find the smallest power of two not below the value and pick
            # whichever neighboring power is closer
            i = 1
            while 2**i < params[par_name]:
                i += 1
            if params[par_name] - (2 ** (i - 1)) < (2**i) - params[par_name]:
                params[par_name] = 2 ** (i - 1)
            else:
                params[par_name] = 2**i
        res = defaultdict(lambda: defaultdict(lambda: {}))
        for k, v in params.items():
            big_key, small_key = k.split("/")[0], "/".join(k.split("/")[1:])
            if len(small_key.split("/")) == 1:
                res[big_key][small_key] = v
            else:
                group, key = small_key.split("/")
                res[big_key][group][key] = v
        model = self.data.loc[search_name]["general"]["model_name"]
        return res, model


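# A usage sketch (hypothetical search record): load the best parameters of a
# saved search, rounding the number of feature maps to the nearest power of two.
#
#     params, model_name = searches.get_best_params(
#         "search_1", round_to_binary=["model/num_f_maps"]
#     )
#     # e.g. a tuned value of 120 becomes 128

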
class Suggestions(SavedRuns):
    """
    A class that manages operations with suggestion records
    """

    def save_suggestion(
        self, episode_name: str, parameters: Dict, meta_parameters: Dict
    ) -> None:
        """
        Save a new suggestion record
        """

        pars = deepcopy(parameters)
        pars["meta"] = meta_parameters
        super().save_episode(episode_name, pars, behaviors_dict=None)


class SavedStores:
    """
    A class that manages operations with saved dataset records
    """

    def __init__(self, path: str):
        """
        Parameters
        ----------
        path : str
            the path to the pickled dataset records dataframe
        """

        self.path = path
        self.data = pd.read_pickle(path)
        self.skip_keys = [
            "feature_save_path",
            "saved_data_path",
            "real_lens",
            "recompute_annotation",
        ]

    def clear(self) -> None:
        """
        Remove all datasets
        """

        for dataset_name in self.data.index:
            self.remove_dataset(dataset_name)

    def dataset_names(self) -> List:
        """
        Get a list of dataset names

        Returns
        -------
        dataset_names : List
            a list of string dataset names
        """

        return list(self.data.index)

    def remove(self, names: List) -> None:
        """
        Remove some datasets

        Parameters
        ----------
        names : List
            a list of string names of the datasets to delete
        """

        for dataset_name in names:
            if dataset_name in self.data.index:
                self.remove_dataset(dataset_name)

    def remove_dataset(self, dataset_name: str) -> None:
        """
        Remove a dataset record

        Parameters
        ----------
        dataset_name : str
            the name of the dataset to remove
        """

        if dataset_name in self.data.index:
            self.data = self.data.drop(index=dataset_name)
            self._save()

    def find_name(self, parameters: Dict) -> str:
        """
        Find a record that satisfies the parameters (if it exists)

        Parameters
        ----------
        parameters : dict
            a dictionary of data parameters

        Returns
        -------
        name : str
            the name of a record that has the same parameters (None if it does not exist; the earliest if there are
            several)
        """

        filter_dict = deepcopy(parameters)
        for key, value in parameters.items():
            if value is None or key in self.skip_keys:
                filter_dict.pop(key)
            elif key not in self.data.columns:
                return None
        saved_annotation = self.data[
            (self.data[list(filter_dict)] == pd.Series(filter_dict)).all(axis=1)
        ]
        for i in range(len(saved_annotation)):
            # a candidate only matches if all of its parameters that are not in
            # the filter are unset
            ok = True
            for key in saved_annotation.columns:
                if key in self.skip_keys:
                    continue
                isnull = pd.isnull(saved_annotation.iloc[i][key])
                if not isinstance(isnull, bool):
                    isnull = False
                if key not in filter_dict and not isnull:
                    ok = False
            if ok:
                name = saved_annotation.iloc[i].name
                return name
        return None

    def save_store(self, episode_name: str, parameters: Dict) -> None:
        """
        Save a new saved dataset record

        Parameters
        ----------
        episode_name : str
            the name of the dataset
        parameters : dict
            a dictionary of data parameters
        """

        pars = deepcopy(parameters)
        for k in parameters:
            if k not in self.data.columns:
                self.data[k] = np.nan
        if self.find_name(pars) is None:
            self.data.loc[episode_name] = pars
        self._save()

    def _save(self):
        """
        Save the dataframe
        """

        self.data.to_pickle(self.path)

    def check_name_validity(self, store_name: str) -> bool:
        """
        Check if a store name already exists

        Parameters
        ----------
        store_name : str
            the name to check

        Returns
        -------
        result : bool
            True if the name can be used
        """

        return store_name not in self.data.index
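

# A usage sketch (hypothetical parameters): look up a saved dataset that matches
# a parameter dictionary and record a new one if none exists.
#
#     stores = SavedStores("/path/to/project/meta/saved_stores.pickle")
#     pars = {"overlap": 50, "feature_extraction": "kinematic"}
#     if stores.find_name(pars) is None:
#         stores.save_store("dataset_1", pars)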
class Run:
 22class Run:
 23    """
 24    A class that manages operations with a single episode record
 25    """
 26
 27    def __init__(
 28        self,
 29        episode_name: str,
 30        project_path: str,
 31        meta_path: str = None,
 32        params: Dict = None,
 33    ):
 34        """
 35        Parameters
 36        ----------
 37        episode_name : str
 38            the name of the episode
 39        meta_path : str, optional
 40            the path to the pickled SavedRuns dataframe
 41        params : dict, optional
 42            alternative to meta_path: pre-loaded pandas Series of episode parameters
 43        """
 44
 45        self.name = episode_name
 46        self.project_path = project_path
 47        if meta_path is not None:
 48            try:
 49                self.params = pd.read_pickle(meta_path).loc[episode_name]
 50            except:
 51                raise ValueError(f"The {episode_name} episode does not exist!")
 52        elif params is not None:
 53            self.params = params
 54        else:
 55            raise ValueError("Either meta_path or params has to be not None")
 56
 57    def training_time(self) -> int:
 58        """
 59        Get the training time in seconds
 60
 61        Returns
 62        -------
 63        training_time : int
 64            the training time in seconds
 65        """
 66
 67        time_str = self.params["meta"].get("training_time")
 68        try:
 69            if time_str is None or np.isnan(time_str):
 70                return np.nan
 71        except TypeError:
 72            pass
 73        h, m, s = time_str.split(":")
 74        seconds = int(h) * 3600 + int(m) * 60 + int(s)
 75        return seconds
 76
 77    def model_file(self, load_epoch: int = None) -> str:
 78        """
 79        Get a checkpoint file path
 80
 81        Parameters
 82        ----------
 83        project_path : str
 84            the current project folder path
 85        load_epoch : int, optional
 86            the epoch to load (the closest checkpoint will be chosen; if not given will be set to last)
 87
 88        Returns
 89        -------
 90        checkpoint_path : str
 91            the path to the checkpoint
 92        """
 93
 94        model_path = correct_path(
 95            self.params["training"]["model_save_path"], self.project_path
 96        )
 97        if load_epoch is None:
 98            model_file = sorted(os.listdir(model_path))[-1]
 99        else:
100            model_files = os.listdir(model_path)
101            if len(model_files) == 0:
102                model_file = None
103            else:
104                epochs = [int(file[5:].split(".")[0]) for file in model_files]
105                diffs = [np.abs(epoch - load_epoch) for epoch in epochs]
106                argmin = np.argmin(diffs)
107                model_file = model_files[argmin]
108        model_file = os.path.join(model_path, model_file)
109        return model_file
110
111    def dataset_name(self) -> str:
112        """
113        Get the dataset name
114
115        Returns
116        -------
117        dataset_name : str
118            the name of the dataset record
119        """
120
121        data_path = correct_path(
122            self.params["data"]["feature_save_path"], self.project_path
123        )
124        dataset_name = os.path.basename(data_path)
125        return dataset_name
126
127    def split_file(self) -> str:
128        """
129        Get the split file
130
131        Returns
132        -------
133        split_path : str
134            the path to the split file
135        """
136
137        return correct_path(self.params["training"]["split_path"], self.project_path)
138
139    def log_file(self) -> str:
140        """
141        Get the log file
142
143        Returns
144        -------
145        log_path : str
146            the path to the log file
147        """
148
149        return correct_path(self.params["training"]["log_file"], self.project_path)
150
151    def split_info(self) -> Dict:
152        """
153        Get the train/test/val split information
154
155        Returns
156        -------
157        split_info : dict
158            a dictionary with [val_frac, test_frac, partition_method] keys and corresponding values
159        """
160
161        val_frac = self.params["training"]["val_frac"]
162        test_frac = self.params["training"]["test_frac"]
163        partition_method = self.params["training"]["partition_method"]
164        return {
165            "val_frac": val_frac,
166            "test_frac": test_frac,
167            "partition_method": partition_method,
168        }
169
170    def same_split_info(self, split_info: Dict) -> bool:
171        """
172        Check whether this episode has the same split information
173
174        Parameters
175        ----------
176        split_info : dict
177            a dictionary with [val_frac, test_frac, partition_method] keys and corresponding values from another episode
178
179        Returns
180        -------
181        result : bool
182            if True, this episode has the same split information
183        """
184
185        self_split_info = self.split_info()
186        for k in ["val_frac", "test_frac", "partition_method"]:
187            if self_split_info[k] != split_info[k]:
188                return False
189        return True
190
191    def get_metrics(self) -> List:
192        """
193        Get a list of tracked metrics
194
195        Returns
196        -------
197        metrics : list
198            a list of tracked metric names
199        """
200
201        return self.params["general"]["metric_functions"]
202
203    def get_metric_log(self, mode: str, metric_name: str) -> np.ndarray:
204        """
205        Get the metric log
206
207        Parameters
208        ----------
209        mode : {'train', 'val'}
210            the mode to get the log from
211        metric_name : str
212            the metric to get the log for (has to be one of the metric computed for this episode during training)
213
214        Returns
215        -------
216        log : np.ndarray
217            the log of metric values (empty if the metric was not computed during training)
218        """
219
220        metric_array = []
221        with open(self.log_file()) as f:
222            for line in f.readlines():
223                if mode == "train" and line.startswith("[epoch"):
224                    line = line.split("]: ")[1]
225                elif mode == "val" and line.startswith("validation"):
226                    line = line.split("validation: ")[1]
227                else:
228                    continue
229                metrics = line.split(", ")
230                for metric in metrics:
231                    name, value = metric.split()
232                    if name == metric_name:
233                        metric_array.append(float(value))
234        return np.array(metric_array)
235
236    def get_epoch_list(self, mode) -> List:
237        """
238        Get a list of epoch indices
239
240        Returns
241        -------
242        epoch_list : list
243            a list of int epoch indices
244        """
245
246        epoch_list = []
247        with open(self.log_file()) as f:
248            for line in f.readlines():
249                if line.startswith("[epoch"):
250                    epoch = int(line[7:].split("]:")[0])
251                    if mode == "train":
252                        epoch_list.append(epoch)
253                elif mode == "val":
254                    epoch_list.append(epoch)
255        return epoch_list
256
257    def get_metrics(self) -> List:
258        """
259        Get a list of metric names in the episode log
260
261        Returns
262        -------
263        metrics : List
264            a list of string metric names
265        """
266
267        metrics = []
268        with open(self.log_file()) as f:
269            for line in f.readlines():
270                if line.startswith("[epoch"):
271                    line = line.split("]: ")[1]
272                elif line.startswith("validation"):
273                    line = line.split("validation: ")[1]
274                else:
275                    continue
276                metric_logs = line.split(", ")
277                for metric in metric_logs:
278                    name, _ = metric.split()
279                    metrics.append(name)
280                break
281        return metrics
282
283    def unfinished(self) -> bool:
284        """
285        Check whether this episode was interrupted
286
287        Returns
288        -------
289        result : bool
290            True if the number of epochs in the log file is smaller than in the parameters
291        """
292
293        num_epoch_theor = self.params["training"]["num_epochs"]
294        log_file = self.log_file()
295        if not isinstance(log_file, str):
296            return False
297        if not os.path.exists(log_file):
298            return True
299        with open(self.log_file()) as f:
300            num_epoch = 0
301            val = False
302            for line in f.readlines():
303                num_epoch += 1
304                if num_epoch == 2 and line.startswith("validation"):
305                    val = True
306            if val:
307                num_epoch //= 2
308        return num_epoch < num_epoch_theor
309
310    def get_class_ind(self, class_name: str) -> int:
311        """
312        Get the integer label from a class name
313
314        Parameters
315        ----------
316        class_name : str
317            the name of the class
318
319        Returns
320        -------
321        class_ind : int
322            the integer label
323        """
324
325        behaviors_dict = self.params["meta"]["behaviors_dict"]
326        for k, v in behaviors_dict.items():
327            if v == class_name:
328                return k
329        raise ValueError(
330            f"The {class_name} class is not in classes predicted by {self.name} ({behaviors_dict})"
331        )
332
333    def get_behaviors_dict(self) -> Dict:
334        """
335        Get behaviors dictionary in the episode
336
337        Returns
338        -------
339        behaviors_dict : dict
340            a dictionary with class indices as keys and labels as values
341        """
342
343        return self.params["meta"]["behaviors_dict"]
344
345    def get_num_classes(self) -> int:
346        """
347        Get number of classes in episode
348
349        Returns
350        -------
351        num_classes : int
352            the number of classes
353        """
354
355        return len(self.params["meta"]["behaviors_dict"])

A class that manages operations with a single episode record

Run( episode_name: str, project_path: str, meta_path: str = None, params: Dict = None)
27    def __init__(
28        self,
29        episode_name: str,
30        project_path: str,
31        meta_path: str = None,
32        params: Dict = None,
33    ):
34        """
35        Parameters
36        ----------
37        episode_name : str
38            the name of the episode
39        meta_path : str, optional
40            the path to the pickled SavedRuns dataframe
41        params : dict, optional
42            alternative to meta_path: pre-loaded pandas Series of episode parameters
43        """
44
45        self.name = episode_name
46        self.project_path = project_path
47        if meta_path is not None:
48            try:
49                self.params = pd.read_pickle(meta_path).loc[episode_name]
50            except:
51                raise ValueError(f"The {episode_name} episode does not exist!")
52        elif params is not None:
53            self.params = params
54        else:
55            raise ValueError("Either meta_path or params has to be not None")

Parameters

episode_name : str the name of the episode meta_path : str, optional the path to the pickled SavedRuns dataframe params : dict, optional alternative to meta_path: pre-loaded pandas Series of episode parameters

def training_time(self) -> int:
57    def training_time(self) -> int:
58        """
59        Get the training time in seconds
60
61        Returns
62        -------
63        training_time : int
64            the training time in seconds
65        """
66
67        time_str = self.params["meta"].get("training_time")
68        try:
69            if time_str is None or np.isnan(time_str):
70                return np.nan
71        except TypeError:
72            pass
73        h, m, s = time_str.split(":")
74        seconds = int(h) * 3600 + int(m) * 60 + int(s)
75        return seconds

Get the training time in seconds

Returns

training_time : int the training time in seconds

def model_file(self, load_epoch: int = None) -> str:
 77    def model_file(self, load_epoch: int = None) -> str:
 78        """
 79        Get a checkpoint file path
 80
 81        Parameters
 82        ----------
 83        project_path : str
 84            the current project folder path
 85        load_epoch : int, optional
 86            the epoch to load (the closest checkpoint will be chosen; if not given will be set to last)
 87
 88        Returns
 89        -------
 90        checkpoint_path : str
 91            the path to the checkpoint
 92        """
 93
 94        model_path = correct_path(
 95            self.params["training"]["model_save_path"], self.project_path
 96        )
 97        if load_epoch is None:
 98            model_file = sorted(os.listdir(model_path))[-1]
 99        else:
100            model_files = os.listdir(model_path)
101            if len(model_files) == 0:
102                return None
103            else:
104                epochs = [int(file[5:].split(".")[0]) for file in model_files]
105                diffs = [np.abs(epoch - load_epoch) for epoch in epochs]
106                argmin = np.argmin(diffs)
107                model_file = model_files[argmin]
108        model_file = os.path.join(model_path, model_file)
109        return model_file
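
The epoch parsing above relies on checkpoint files being named like epoch10.pt, so that file[5:].split(".")[0] recovers the epoch index. A small illustration of the closest-checkpoint selection (file names hypothetical):

    import numpy as np

    model_files = ["epoch5.pt", "epoch10.pt", "epoch25.pt"]
    epochs = [int(f[5:].split(".")[0]) for f in model_files]  # [5, 10, 25]
    diffs = [np.abs(e - 12) for e in epochs]                  # [7, 2, 13]
    print(model_files[np.argmin(diffs)])                      # epoch10.pt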

def dataset_name(self) -> str:
111    def dataset_name(self) -> str:
112        """
113        Get the dataset name
114
115        Returns
116        -------
117        dataset_name : str
118            the name of the dataset record
119        """
120
121        data_path = correct_path(
122            self.params["data"]["feature_save_path"], self.project_path
123        )
124        dataset_name = os.path.basename(data_path)
125        return dataset_name

def split_file(self) -> str:
127    def split_file(self) -> str:
128        """
129        Get the split file
130
131        Returns
132        -------
133        split_path : str
134            the path to the split file
135        """
136
137        return correct_path(self.params["training"]["split_path"], self.project_path)

def log_file(self) -> str:
139    def log_file(self) -> str:
140        """
141        Get the log file
142
143        Returns
144        -------
145        log_path : str
146            the path to the log file
147        """
148
149        return correct_path(self.params["training"]["log_file"], self.project_path)

def split_info(self) -> Dict:
151    def split_info(self) -> Dict:
152        """
153        Get the train/test/val split information
154
155        Returns
156        -------
157        split_info : dict
158            a dictionary with [val_frac, test_frac, partition_method] keys and corresponding values
159        """
160
161        val_frac = self.params["training"]["val_frac"]
162        test_frac = self.params["training"]["test_frac"]
163        partition_method = self.params["training"]["partition_method"]
164        return {
165            "val_frac": val_frac,
166            "test_frac": test_frac,
167            "partition_method": partition_method,
168        }

def same_split_info(self, split_info: Dict) -> bool:
170    def same_split_info(self, split_info: Dict) -> bool:
171        """
172        Check whether this episode has the same split information
173
174        Parameters
175        ----------
176        split_info : dict
177            a dictionary with [val_frac, test_frac, partition_method] keys and corresponding values from another episode
178
179        Returns
180        -------
181        result : bool
182            if True, this episode has the same split information
183        """
184
185        self_split_info = self.split_info()
186        for k in ["val_frac", "test_frac", "partition_method"]:
187            if self_split_info[k] != split_info[k]:
188                return False
189        return True
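
For instance, to check whether a new episode can reuse the split file of an older one (episode names and paths hypothetical, as in the constructor sketch above):

    from dlc2action.project.meta import Run

    project_path = "/home/user/my_project"                    # hypothetical
    meta_path = f"{project_path}/meta/saved_episodes.pickle"  # hypothetical
    old_run = Run("first_experiment", project_path, meta_path=meta_path)
    new_run = Run("second_experiment", project_path, meta_path=meta_path)
    if new_run.same_split_info(old_run.split_info()):
        split_path = old_run.split_file()  # the same split can be reused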

def get_metrics(self) -> List:
257    def get_metrics(self) -> List:
258        """
259        Get a list of metric names in the episode log
260
261        Returns
262        -------
263        metrics : List
264            a list of string metric names
265        """
266
267        metrics = []
268        with open(self.log_file()) as f:
269            for line in f.readlines():
270                if line.startswith("[epoch"):
271                    line = line.split("]: ")[1]
272                elif line.startswith("validation"):
273                    line = line.split("validation: ")[1]
274                else:
275                    continue
276                metric_logs = line.split(", ")
277                for metric in metric_logs:
278                    name, _ = metric.split()
279                    metrics.append(name)
280                break
281        return metrics
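
The log parsers in get_metrics and get_metric_log below assume lines of the following shape (this example line is hypothetical):

    line = "[epoch 3]: loss 0.412, recall 0.735\n"
    body = line.split("]: ")[1]        # "loss 0.412, recall 0.735\n"
    for metric in body.split(", "):
        name, value = metric.split()   # ("loss", "0.412"), then ("recall", "0.735")
        print(name, float(value))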

def get_metric_log(self, mode: str, metric_name: str) -> numpy.ndarray:
203    def get_metric_log(self, mode: str, metric_name: str) -> np.ndarray:
204        """
205        Get the metric log
206
207        Parameters
208        ----------
209        mode : {'train', 'val'}
210            the mode to get the log from
211        metric_name : str
212            the metric to get the log for (has to be one of the metrics computed for this episode during training)
213
214        Returns
215        -------
216        log : np.ndarray
217            the log of metric values (empty if the metric was not computed during training)
218        """
219
220        metric_array = []
221        with open(self.log_file()) as f:
222            for line in f.readlines():
223                if mode == "train" and line.startswith("[epoch"):
224                    line = line.split("]: ")[1]
225                elif mode == "val" and line.startswith("validation"):
226                    line = line.split("validation: ")[1]
227                else:
228                    continue
229                metrics = line.split(", ")
230                for metric in metrics:
231                    name, value = metric.split()
232                    if name == metric_name:
233                        metric_array.append(float(value))
234        return np.array(metric_array)

def get_epoch_list(self, mode: str) -> List:
236    def get_epoch_list(self, mode: str) -> List:
237        """
238        Get a list of epoch indices
239
240        Parameters
241        ----------
242        mode : {'train', 'val'}
243            the mode to get the epoch indices from
244
245        Returns
246        -------
247        epoch_list : list
248            a list of int epoch indices
249        """
250
251        epoch_list = []
252        with open(self.log_file()) as f:
253            for line in f.readlines():
254                if line.startswith("[epoch"):
255                    epoch = int(line[7:].split("]:")[0])
256                    if mode == "train":
257                        epoch_list.append(epoch)
258                elif mode == "val" and line.startswith("validation"):
259                    epoch_list.append(epoch)
260        return epoch_list

def unfinished(self) -> bool:
283    def unfinished(self) -> bool:
284        """
285        Check whether this episode was interrupted
286
287        Returns
288        -------
289        result : bool
290            True if the number of epochs in the log file is smaller than in the parameters
291        """
292
293        num_epoch_theor = self.params["training"]["num_epochs"]
294        log_file = self.log_file()
295        if not isinstance(log_file, str):
296            return False
297        if not os.path.exists(log_file):
298            return True
299        with open(self.log_file()) as f:
300            num_epoch = 0
301            val = False
302            for line in f.readlines():
303                num_epoch += 1
304                if num_epoch == 2 and line.startswith("validation"):
305                    val = True
306            if val:
307                num_epoch //= 2
308        return num_epoch < num_epoch_theor

def get_class_ind(self, class_name: str) -> int:
310    def get_class_ind(self, class_name: str) -> int:
311        """
312        Get the integer label from a class name
313
314        Parameters
315        ----------
316        class_name : str
317            the name of the class
318
319        Returns
320        -------
321        class_ind : int
322            the integer label
323        """
324
325        behaviors_dict = self.params["meta"]["behaviors_dict"]
326        for k, v in behaviors_dict.items():
327            if v == class_name:
328                return k
329        raise ValueError(
330            f"The {class_name} class is not in classes predicted by {self.name} ({behaviors_dict})"
331        )
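
For example, with a hypothetical behaviors_dict of {0: 'grooming', 1: 'rearing'}, get_class_ind('rearing') returns 1; get_behaviors_dict and get_num_classes (see the source listing above) return the mapping itself and its length (here 2).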

class DecisionThresholds:
358class DecisionThresholds:
359    """
360    A class that saves and looks up tuned decision thresholds
361    """
362
363    def __init__(self, path: str) -> None:
364        """
365        Parameters
366        ----------
367        path : str
368            the path to the pickled decision thresholds dataframe
369        """
370
371        self.path = path
372        self.data = pd.read_pickle(path)
373
374    def save_thresholds(
375        self,
376        episode_names: List,
377        epochs: List,
378        metric_name: str,
379        metric_parameters: Dict,
380        thresholds: List,
381    ) -> None:
382        """
383        Add a new record
384
385        Parameters
386        ----------
387        episode_names : list
388            a list of episode names
389        epochs : list
390            a list of epoch indices
391        metric_name : str
392            the name of the metric the thresholds were tuned on
393        metric_parameters : dict
394            the metric parameter dictionary
395        thresholds : list
396            a list of float decision thresholds
397        """
398
399        episodes = set(zip(episode_names, epochs))
400        for key in ["average", "threshold_value", "ignored_classes"]:
401            if key in metric_parameters:
402                metric_parameters.pop(key)
403        parameters = {(metric_name, k): v for k, v in metric_parameters.items()}
404        parameters["thresholds"] = thresholds
405        parameters["episodes"] = episodes
406        pars = {k: [v] for k, v in parameters.items()}
407        self.data = pd.concat([self.data, pd.DataFrame.from_dict(pars)], axis=0)
408        self._save()
409
410    def find_thresholds(
411        self,
412        episode_names: List,
413        epochs: List,
414        metric_name: str,
415        metric_parameters: Dict,
416    ) -> Union[List, None]:
417        """
418        Find a record
419
420        Parameters
421        ----------
422        episode_names : list
423            a list of episode names
424        epochs : list
425            a list of epoch indices
426        metric_name : str
427            the name of the metric the thresholds were tuned on
428        metric_parameters : dict
429            the metric parameter dictionary
430
431        Returns
432        -------
433        thresholds : list
434            a list of float decision thresholds (None if no matching record is found)
435        """
436
437        episodes = set(zip(episode_names, epochs))
438        for key in ["average", "threshold_value", "ignored_classes"]:
439            if key in metric_parameters:
440                metric_parameters.pop(key)
441        parameters = {(metric_name, k): v for k, v in metric_parameters.items()}
442        parameters["episodes"] = episodes
443        filter = deepcopy(parameters)
444        for key, value in parameters.items():
445            if value is None:
446                filter.pop(key)
447            elif key not in self.data.columns:
448                return None
449        data = self.data[(self.data[list(filter)] == pd.Series(filter)).all(axis=1)]
450        if len(data) > 0:
451            thresholds = data.iloc[0]["thresholds"]
452            return thresholds
453        else:
454            return None
455
456    def _save(self) -> None:
457        """
458        Save the records
459        """
460
461        self.data.copy().to_pickle(self.path)
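
A sketch of the intended round trip (the pickle path, episode names and metric parameters are hypothetical; note that both methods drop the 'average', 'threshold_value' and 'ignored_classes' keys from metric_parameters before matching):

    thresholds = DecisionThresholds("/home/user/my_project/meta/thresholds.pickle")
    thresholds.save_thresholds(
        episode_names=["first_experiment"],
        epochs=[50],
        metric_name="f1",
        metric_parameters={"average": "none", "iou_threshold": 0.5},  # hypothetical
        thresholds=[0.4, 0.6],
    )
    # later: returns [0.4, 0.6] for a matching record, or None otherwise
    found = thresholds.find_thresholds(
        ["first_experiment"], [50], "f1", {"iou_threshold": 0.5}
    )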

class SavedRuns:
464class SavedRuns:
465    """
466    A class that manages operations with all episode (or prediction) records
467    """
468
469    def __init__(self, path: str, project_path: str) -> None:
470        """
471        Parameters
472        ----------
473        path : str
474            the path to the pickled SavedRuns dataframe
475        project_path : str
476            the path to the project folder
477        """
478        self.path = path
479        self.project_path = project_path
480        self.data = pd.read_pickle(path)
481    def update(
482        self,
483        data: pd.DataFrame,
484        data_path: str,
485        annotation_path: str,
486        name_map: Dict = None,
487        force: bool = False,
488    ) -> None:
489        """
490        Update with new data
491
492        Parameters
493        ----------
494        data : pd.DataFrame
495            the new dataframe
496        data_path : str
497            the new data path
498        annotation_path : str
499            the new annotation path
500        name_map : dict, optional
501            the name change dictionary; keys are old episode names and values are new episode names
502        force : bool, default False
503            replace existing episodes if `True`
504        """
505
506        if name_map is None:
507            name_map = {}
508        data = data.rename(index=name_map)
509        for episode in data.index:
510            new_model = os.path.join(self.project_path, "results", "model", episode)
511            data.loc[episode, ("training", "model_save_path")] = new_model
512            new_log = os.path.join(
513                self.project_path, "results", "logs", f"{episode}.txt"
514            )
515            data.loc[episode, ("training", "log_file")] = new_log
516            old_split = data.loc[episode, ("training", "split_path")]
517            if old_split is None:
518                new_split = None
519            else:
520                new_split = os.path.join(
521                    self.project_path, "results", "splits", os.path.basename(old_split)
522                )
523            data.loc[episode, ("training", "split_path")] = new_split
524            data.loc[episode, ("data", "data_path")] = data_path
525            data.loc[episode, ("data", "annotation_path")] = annotation_path
526            if episode in self.data.index:
527                if force:
528                    self.data = self.data.drop(index=[episode])
529                else:
530                    raise RuntimeError(f"The {episode} episode name is already taken!")
531        self.data = pd.concat([self.data, data])
532        self._save()
533
534    def get_subset(self, episode_names: List) -> pd.DataFrame:
535        """
536        Get a subset of the raw metadata
537
538        Parameters
539        ----------
540        episode_names : list
541            a list of the episodes to include
542        """
543
544        for episode in episode_names:
545            if episode not in self.data.index:
546                raise ValueError(
547                    f"The {episode} episode is not in the records; please run `Project.list_episodes()` to explore the records"
548                )
549        return self.data.loc[episode_names]
550
551    def get_saved_data_path(self, episode_name: str) -> str:
552        """
553        Get the `saved_data_path` parameter for the episode
554
555        Parameters
556        ----------
557        episode_name : str
558            the name of the episode
559
560        Returns
561        -------
562        saved_data_path : str
563            the saved data path
564        """
565
566        return self.data.loc[episode_name]["data"]["saved_data_path"]
567
568    def check_name_validity(self, episode_name: str) -> bool:
569        """
570        Check if an episode name already exists
571
572        Parameters
573        ----------
574        episode_name : str
575            the name to check
576
577        Returns
578        -------
579        result : bool
580            True if the name can be used
581        """
582
583        if episode_name in self.data.index:
584            return False
585        else:
586            return True
587
588    def update_episode_metrics(self, episode_name: str, metrics: Dict) -> None:
589        """
590        Update meta data with evaluation results
591
592        Parameters
593        ----------
594        episode_name : str
595            the name of the episode to update
596        metrics : dict
597            a dictionary of the metrics
598        """
599
600        for key, value in metrics.items():
601            self.data.loc[episode_name, ("results", key)] = value
602        self._save()
603
604    def save_episode(
605        self,
606        episode_name: str,
607        parameters: Dict,
608        behaviors_dict: Dict,
609        suppress_validation: bool = False,
610        training_time: str = None,
611    ) -> None:
612        """
613        Save a new run record
614
615        Parameters
616        ----------
617        episode_name : str
618            the name of the episode
619        parameters : dict
620            the parameters to save
621        behaviors_dict : dict
622            the dictionary of behaviors (keys are indices, values are names)
623        suppress_validation : bool, default False
624            if True, existing episode with the same name will be overwritten
625        training_time : str, optional
626            the training time in '%H:%M:%S' format
627        """
628
629        if not suppress_validation and episode_name in self.data.index:
630            raise ValueError(f"Episode {episode_name} already exists!")
631        pars = deepcopy(parameters)
632        if "meta" not in pars:
633            pars["meta"] = {
634                "time": strftime("%Y-%m-%d %H:%M:%S", localtime()),
635                "behaviors_dict": behaviors_dict,
636            }
637        else:
638            pars["meta"]["time"] = strftime("%Y-%m-%d %H:%M:%S", localtime())
639            pars["meta"]["behaviors_dict"] = behaviors_dict
640        if training_time is not None:
641            pars["meta"]["training_time"] = training_time
642        if len(parameters.keys()) > 1:
643            pars["losses"] = pars["losses"].get(pars["general"]["loss_function"], {})
644            for metric_name in pars["general"]["metric_functions"]:
645                pars[metric_name] = pars["metrics"].get(metric_name, {})
646            if pars["general"].get("ssl", None) is not None:
647                for ssl_name in pars["general"]["ssl"]:
648                    pars[ssl_name] = pars["ssl"].get(ssl_name, {})
649            for group_name in ["metrics", "ssl"]:
650                if group_name in pars:
651                    pars.pop(group_name)
652        data = {
653            (big_key, small_key): value
654            for big_key, big_value in pars.items()
655            for small_key, value in big_value.items()
656        }
657        list_keys = []
658        with warnings.catch_warnings():
659            warnings.filterwarnings("ignore", message="DataFrame is highly fragmented")
660            for k, v in data.items():
661                if k not in self.data.columns:
662                    self.data[k] = np.nan
663                if isinstance(v, list) and not isinstance(v, str):
664                    list_keys.append(k)
665            for k in list_keys:
666                self.data[k] = self.data[k].astype(object)
667            self.data.loc[episode_name] = data
668        self._save()
669
670    def load_parameters(self, episode_name: str) -> Dict:
671        """
672        Load the task parameters from a record
673
674        Parameters
675        ----------
676        episode_name : str
677            the name of the episode to load
678
679        Returns
680        -------
681        parameters : dict
682            the loaded task parameters
683        """
684
685        parameters = defaultdict(lambda: defaultdict(lambda: {}))
686        episode = self.data.loc[episode_name].dropna().to_dict()
687        keys = ["data", "augmentations", "general", "training", "model", "features"]
688        for key in episode:
689            big_key, small_key = key
690            if big_key in keys:
691                parameters[big_key][small_key] = episode[key]
692        # parameters = {k: dict(v) for k, v in parameters.items()}
693        ssl_keys = parameters["general"].get("ssl", None)
694        metric_keys = parameters["general"].get("metric_functions", None)
695        loss_key = parameters["general"]["loss_function"]
696        if ssl_keys is None:
697            ssl_keys = []
698        if metric_keys is None:
699            metric_keys = []
700        for key in episode:
701            big_key, small_key = key
702            if big_key in ssl_keys:
703                parameters["ssl"][big_key][small_key] = episode[key]
704            elif big_key in metric_keys:
705                parameters["metrics"][big_key][small_key] = episode[key]
706            elif big_key == "losses":
707                parameters["losses"][loss_key][small_key] = episode[key]
708        parameters = {k: dict(v) for k, v in parameters.items()}
709        parameters["general"]["num_classes"] = Run(
710            episode_name, self.project_path, params=self.data.loc[episode_name]
711        ).get_num_classes()
712        return parameters
713
714    def get_active_datasets(self) -> List:
715        """
716        Get a list of names of datasets that are used by unfinished episodes
717
718        Returns
719        -------
720        active_datasets : list
721            a list of dataset names used by unfinished episodes
722        """
723
724        active_datasets = []
725        for episode_name in self.unfinished_episodes():
726            run = Run(
727                episode_name, self.project_path, params=self.data.loc[episode_name]
728            )
729            active_datasets.append(run.dataset_name())
730        return active_datasets
731
732    def list_episodes(
733        self,
734        episode_names: List = None,
735        value_filter: str = "",
736        display_parameters: List = None,
737    ) -> pd.DataFrame:
738        """
739        Get a filtered pandas dataframe with episode metadata
740
741        Parameters
742        ----------
743        episode_names : List
744            a list of strings of episode names
745        value_filter : str
746            a string of filters to apply of this general structure:
747            'group_name1/par_name1::(<>=)value1,group_name2/par_name2::(<>=)value2', e.g.
748            'data/overlap::=50,results/recall::>0.5,data/feature_extraction::=kinematic'
749        display_parameters : List
750            list of parameters to display (e.g. ['data/overlap', 'results/recall'])
751
752        Returns
753        -------
754        pandas.DataFrame
755            the filtered dataframe
756        """
757
758        if episode_names is not None:
759            data = deepcopy(self.data.loc[episode_names])
760        else:
761            data = deepcopy(self.data)
762        if len(data) == 0:
763            return pd.DataFrame()
764        try:
765            filters = value_filter.split(",")
766            if filters == [""]:
767                filters = []
768            for f in filters:
769                par_name, condition = f.split("::")
770                group_name, par_name = par_name.split("/")
771                sign, value = condition[0], condition[1:]
772                if value[0] == "=":
773                    sign += "="
774                    value = value[1:]
775                try:
776                    value = float(value)
777                except ValueError:
778                    if value == "True":
779                        value = True
780                    elif value == "False":
781                        value = False
782                    elif value == "None":
783                        value = None
784                if value is None:
785                    if sign == "=":
786                        data = data[data[group_name][par_name].isna()]
787                    elif sign == "!=":
788                        data = data[~data[group_name][par_name].isna()]
789                elif sign == ">":
790                    data = data[data[group_name][par_name] > value]
791                elif sign == ">=":
792                    data = data[data[group_name][par_name] >= value]
793                elif sign == "<":
794                    data = data[data[group_name][par_name] < value]
795                elif sign == "<=":
796                    data = data[data[group_name][par_name] <= value]
797                elif sign == "=":
798                    data = data[data[group_name][par_name] == value]
799                elif sign == "!=":
800                    data = data[data[group_name][par_name] != value]
801                else:
802                    raise ValueError(
803                        "Please use one of the signs: [>, <, >=, <=, =, !=]"
804                    )
805        except ValueError:
806            raise ValueError(
807                f"The {value_filter} filter is not valid, please use the following format:"
808                f" 'group1/parameter1::[sign][value],group2/parameter2::[sign][value]', "
809                f"e.g. 'training/num_epochs::>=200,model/num_f_maps::=128,meta/time::>2022-06-01'"
810            )
811        if display_parameters is not None:
812            if type(display_parameters[0]) is str:
813                display_parameters = [
814                    (x.split("/")[0], x.split("/")[1]) for x in display_parameters
815                ]
816            display_parameters = [x for x in display_parameters if x in data.columns]
817            data = data[display_parameters]
818        return data
819
820    def rename_episode(self, episode_name, new_episode_name):
821        """Rename an episode record and update its model and log paths"""
822        if episode_name in self.data.index and new_episode_name not in self.data.index:
823            self.data.loc[new_episode_name] = self.data.loc[episode_name]
824            model_path = self.data.loc[new_episode_name, ("training", "model_save_path")]
825            self.data.loc[new_episode_name, ("training", "model_save_path")] = os.path.join(
826                os.path.dirname(model_path), new_episode_name
827            )
828            log_path = self.data.loc[new_episode_name, ("training", "log_file")]
829            self.data.loc[new_episode_name, ("training", "log_file")] = os.path.join(
830                os.path.dirname(log_path), f"{new_episode_name}.txt"
831            )
832            self.data = self.data.drop(index=episode_name)
833            self._save()
834        else:
835            raise ValueError(f"Either the {episode_name} episode does not exist or the {new_episode_name} name is already taken")
835
836    def remove_episode(self, episode_name: str) -> None:
837        """
838        Remove an episode record from the meta file
839
840        Parameters
841        ----------
842        episode_name : str
843            the name of the episode to remove
844        """
845
846        if episode_name in self.data.index:
847            self.data = self.data.drop(index=episode_name)
848            self._save()
849
850    def unfinished_episodes(self) -> List:
851        """
852        Get a list of unfinished episodes (currently running or interrupted)
853
854        Returns
855        -------
856        interrupted_episodes: List
857            a list of string names of unfinished episodes in the records
858        """
859
860        unfinished = []
861        for name, params in self.data.iterrows():
862            if Run(name, project_path=self.project_path, params=params).unfinished():
863                unfinished.append(name)
864        return unfinished
865
866    def update_episode_results(
867        self,
868        episode_name: str,
869        logs: Tuple,
870        training_time: str = None,
871    ) -> None:
872        """
873        Add results to an episode record
874
875        Parameters
876        ----------
877        episode_name : str
878            the name of the episode to update
879        logs : tuple
880            the log tuple returned by task.train() (the second element is the metrics dictionary)
881        training_time : str, optional
882            the training time in '%H:%M:%S' format
883        """
884
885        metrics_log = logs[1]
886        results = {}
887        for key, value in metrics_log["val"].items():
888            results[("results", key)] = value[-1]
889        if training_time is not None:
890            results[("meta", "training_time")] = training_time
891        for k, v in results.items():
892            self.data.loc[episode_name, k] = v
893        self._save()
894
895    def get_runs(self, episode_name: str) -> List:
896        """
897        Get a list of runs with this episode name (episodes like `episode_name::0`)
898
899        Parameters
900        ----------
901        episode_name : str
902            the name of the episode
903
904        Returns
905        -------
906        runs_list : List
907            a list of string run names
908        """
909
910        if episode_name is None:
911            return []
912        index = self.data.index
913        runs_list = []
914        for name in index:
915            if name.startswith(episode_name):
916                split = name.split("::")
917                if split[0] == episode_name:
918                    if (len(split) > 1 and split[-1].isnumeric()) or len(split) == 1:
919                        runs_list.append(name)
920                elif name == episode_name:
921                    runs_list.append(name)
922        return runs_list
923
924    def _save(self):
925        """
926        Save the dataframe
927        """
928
929        self.data.copy().to_pickle(self.path)
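
A usage sketch for the episode records (paths and episode names hypothetical; the filter string follows the format documented in list_episodes):

    saved = SavedRuns(
        "/home/user/my_project/meta/saved_episodes.pickle",  # hypothetical
        "/home/user/my_project",
    )
    # episodes trained for at least 200 epochs with validation recall above 0.5
    df = saved.list_episodes(
        value_filter="training/num_epochs::>=200,results/recall::>0.5",
        display_parameters=["training/num_epochs", "results/recall"],
    )
    # runs saved as "experiment::0", "experiment::1", ... are grouped by get_runs
    runs = saved.get_runs("experiment")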

A class that manages operations with all episode (or prediction) records

SavedRuns(path: str, project_path: str)
469    def __init__(self, path: str, project_path: str) -> None:
470        """
471        Parameters
472        ----------
473        path : str
474            the path to the pickled SavedRuns dataframe
475        """
476
477        self.path = path
478        self.project_path = project_path
479        self.data = pd.read_pickle(path)

Parameters

path : str the path to the pickled SavedRuns dataframe

def update( self, data: pandas.core.frame.DataFrame, data_path: str, annotation_path: str, name_map: Dict = None, force: bool = False) -> None:
481    def update(
482        self,
483        data: pd.DataFrame,
484        data_path: str,
485        annotation_path: str,
486        name_map: Dict = None,
487        force: bool = False,
488    ) -> None:
489        """
490        Update with new data
491
492        Parameters
493        ----------
494        data : pd.DataFrame
495            the new dataframe
496        data_path : str
497            the new data path
498        annotation_path : str
499            the new annotation path
500        name_map : dict, optional
501            the name change dictionary; keys are old episode names and values are new episode names
502        force : bool, default False
503            replace existing episodes if `True`
504        """
505
506        if name_map is None:
507            name_map = {}
508        data = data.rename(index=name_map)
509        for episode in data.index:
510            new_model = os.path.join(self.project_path, "results", "model", episode)
511            data.loc[episode, ("training", "model_save_path")] = new_model
512            new_log = os.path.join(
513                self.project_path, "results", "logs", f"{episode}.txt"
514            )
515            data.loc[episode, ("training", "log_file")] = new_log
516            old_split = data.loc[episode, ("training", "split_path")]
517            if old_split is None:
518                new_split = None
519            else:
520                new_split = os.path.join(
521                    self.project_path, "results", "splits", os.path.basename(old_split)
522                )
523            data.loc[episode, ("training", "split_path")] = new_split
524            data.loc[episode, ("data", "data_path")] = data_path
525            data.loc[episode, ("data", "annotation_path")] = annotation_path
526            if episode in self.data.index:
527                if force:
528                    self.data = self.data.drop(index=[episode])
529                else:
530                    raise RuntimeError(f"The {episode} episode name is already taken!")
531        self.data = pd.concat([self.data, data])
532        self._save()

Update with new data

Parameters

data : pd.DataFrame the new dataframe data_path : str the new data path annotation_path : str the new annotation path name_map : dict, optional the name change dictionary; keys are old episode names and values are new episode names force : bool, default False replace existing episodes if True

def get_subset(self, episode_names: List) -> pandas.core.frame.DataFrame:
534    def get_subset(self, episode_names: List) -> pd.DataFrame:
535        """
536        Get a subset of the raw metadata
537
538        Parameters
539        ----------
540        episode_names : list
541            a list of the episodes to include
542        """
543
544        for episode in episode_names:
545            if episode not in self.data.index:
546                raise ValueError(
547                    f"The {episode} episode is not in the records; please run `Project.list_episodes()` to explore the records"
548                )
549        return self.data.loc[episode_names]

Get a subset of the raw metadata

Parameters

episode_names : list a list of the episodes to include

def get_saved_data_path(self, episode_name: str) -> str:
551    def get_saved_data_path(self, episode_name: str) -> str:
552        """
553        Get the `saved_data_path` parameter for the episode
554
555        Parameters
556        ----------
557        episode_name : str
558            the name of the episode
559
560        Returns
561        -------
562        saved_data_path : str
563            the saved data path
564        """
565
566        return self.data.loc[episode_name]["data"]["saved_data_path"]

Get the saved_data_path parameter for the episode

Parameters

episode_name : str the name of the episode

Returns

saved_data_path : str the saved data path

def check_name_validity(self, episode_name: str) -> bool:
568    def check_name_validity(self, episode_name: str) -> bool:
569        """
570        Check if an episode name already exists
571
572        Parameters
573        ----------
574        episode_name : str
575            the name to check
576
577        Returns
578        -------
579        result : bool
580            True if the name can be used
581        """
582
583        if episode_name in self.data.index:
584            return False
585        else:
586            return True

Check if an episode name already exists

Parameters

episode_name : str the name to check

Returns

result : bool True if the name can be used

def update_episode_metrics(self, episode_name: str, metrics: Dict) -> None:
588    def update_episode_metrics(self, episode_name: str, metrics: Dict) -> None:
589        """
590        Update meta data with evaluation results
591
592        Parameters
593        ----------
594        episode_name : str
595            the name of the episode to update
596        metrics : dict
597            a dictionary of the metrics
598        """
599
600        for key, value in metrics.items():
601            self.data.loc[episode_name, ("results", key)] = value
602        self._save()

Update meta data with evaluation results

Parameters

episode_name : str the name of the episode to update metrics : dict a dictionary of the metrics

def save_episode( self, episode_name: str, parameters: Dict, behaviors_dict: Dict, suppress_validation: bool = False, training_time: str = None) -> None:
604    def save_episode(
605        self,
606        episode_name: str,
607        parameters: Dict,
608        behaviors_dict: Dict,
609        suppress_validation: bool = False,
610        training_time: str = None,
611    ) -> None:
612        """
613        Save a new run record
614
615        Parameters
616        ----------
617        episode_name : str
618            the name of the episode
619        parameters : dict
620            the parameters to save
621        behaviors_dict : dict
622            the dictionary of behaviors (keys are indices, values are names)
623        suppress_validation : bool, optional False
624            if True, existing episode with the same name will be overwritten
625        training_time : str, optional
626            the training time in '%H:%M:%S' format
627        """
628
629        if not suppress_validation and episode_name in self.data.index:
630            raise ValueError(f"Episode {episode_name} already exists!")
631        pars = deepcopy(parameters)
632        if "meta" not in pars:
633            pars["meta"] = {
634                "time": strftime("%Y-%m-%d %H:%M:%S", localtime()),
635                "behaviors_dict": behaviors_dict,
636            }
637        else:
638            pars["meta"]["time"] = strftime("%Y-%m-%d %H:%M:%S", localtime())
639            pars["meta"]["behaviors_dict"] = behaviors_dict
640        if training_time is not None:
641            pars["meta"]["training_time"] = training_time
642        if len(parameters.keys()) > 1:
643            pars["losses"] = pars["losses"].get(pars["general"]["loss_function"], {})
644            for metric_name in pars["general"]["metric_functions"]:
645                pars[metric_name] = pars["metrics"].get(metric_name, {})
646            if pars["general"].get("ssl", None) is not None:
647                for ssl_name in pars["general"]["ssl"]:
648                    pars[ssl_name] = pars["ssl"].get(ssl_name, {})
649            for group_name in ["metrics", "ssl"]:
650                if group_name in pars:
651                    pars.pop(group_name)
652        data = {
653            (big_key, small_key): value
654            for big_key, big_value in pars.items()
655            for small_key, value in big_value.items()
656        }
657        list_keys = []
658        with warnings.catch_warnings():
659            warnings.filterwarnings("ignore", message="DataFrame is highly fragmented")
660            for k, v in data.items():
661                if k not in self.data.columns:
662                    self.data[k] = np.nan
663                if isinstance(v, list) and not isinstance(v, str):
664                    list_keys.append(k)
665            for k in list_keys:
666                self.data[k] = self.data[k].astype(object)
667            self.data.loc[episode_name] = data
668        self._save()

Save a new run record

Parameters

episode_name : str the name of the episode parameters : dict the parameters to save behaviors_dict : dict the dictionary of behaviors (keys are indices, values are names) suppress_validation : bool, optional False if True, existing episode with the same name will be overwritten training_time : str, optional the training time in '%H:%M:%S' format

def load_parameters(self, episode_name: str) -> Dict:
670    def load_parameters(self, episode_name: str) -> Dict:
671        """
672        Load the task parameters from a record
673
674        Parameters
675        ----------
676        episode_name : str
677            the name of the episode to load
678
679        Returns
680        -------
681        parameters : dict
682            the loaded task parameters
683        """
684
685        parameters = defaultdict(lambda: defaultdict(lambda: {}))
686        episode = self.data.loc[episode_name].dropna().to_dict()
687        keys = ["data", "augmentations", "general", "training", "model", "features"]
688        for key in episode:
689            big_key, small_key = key
690            if big_key in keys:
691                parameters[big_key][small_key] = episode[key]
692        # parameters = {k: dict(v) for k, v in parameters.items()}
693        ssl_keys = parameters["general"].get("ssl", None)
694        metric_keys = parameters["general"].get("metric_functions", None)
695        loss_key = parameters["general"]["loss_function"]
696        if ssl_keys is None:
697            ssl_keys = []
698        if metric_keys is None:
699            metric_keys = []
700        for key in episode:
701            big_key, small_key = key
702            if big_key in ssl_keys:
703                parameters["ssl"][big_key][small_key] = episode[key]
704            elif big_key in metric_keys:
705                parameters["metrics"][big_key][small_key] = episode[key]
706            elif big_key == "losses":
707                parameters["losses"][loss_key][small_key] = episode[key]
708        parameters = {k: dict(v) for k, v in parameters.items()}
709        parameters["general"]["num_classes"] = Run(
710            episode_name, self.project_path, params=self.data.loc[episode_name]
711        ).get_num_classes()
712        return parameters

Load the task parameters from a record

Parameters

episode_name : str the name of the episode to load

Returns

parameters : dict the loaded task parameters

def get_active_datasets(self) -> List:
714    def get_active_datasets(self) -> List:
715        """
716        Get a list of names of datasets that are used by unfinished episodes
717
718        Returns
719        -------
720        active_datasets : list
721            a list of dataset names used by unfinished episodes
722        """
723
724        active_datasets = []
725        for episode_name in self.unfinished_episodes():
726            run = Run(
727                episode_name, self.project_path, params=self.data.loc[episode_name]
728            )
729            active_datasets.append(run.dataset_name())
730        return active_datasets

Get a list of names of datasets that are used by unfinished episodes

Returns

active_datasets : list a list of dataset names used by unfinished episodes

def list_episodes( self, episode_names: List = None, value_filter: str = '', display_parameters: List = None) -> pandas.core.frame.DataFrame:
732    def list_episodes(
733        self,
734        episode_names: List = None,
735        value_filter: str = "",
736        display_parameters: List = None,
737    ) -> pd.DataFrame:
738        """
739        Get a filtered pandas dataframe with episode metadata
740
741        Parameters
742        ----------
743        episode_names : List
744            a list of strings of episode names
745        value_filter : str
746            a string of filters to apply of this general structure:
747            'group_name1/par_name1::(<>=)value1,group_name2/par_name2::(<>=)value2', e.g.
748            'data/overlap::=50,results/recall::>0.5,data/feature_extraction::=kinematic'
749        display_parameters : List
750            list of parameters to display (e.g. ['data/overlap', 'results/recall'])
751
752        Returns
753        -------
754        pandas.DataFrame
755            the filtered dataframe
756        """
757
758        if episode_names is not None:
759            data = deepcopy(self.data.loc[episode_names])
760        else:
761            data = deepcopy(self.data)
762        if len(data) == 0:
763            return pd.DataFrame()
764        try:
765            filters = value_filter.split(",")
766            if filters == [""]:
767                filters = []
768            for f in filters:
769                par_name, condition = f.split("::")
770                group_name, par_name = par_name.split("/")
771                sign, value = condition[0], condition[1:]
772                if value[0] == "=":
773                    sign += "="
774                    value = value[1:]
775                try:
776                    value = float(value)
777                except:
778                    if value == "True":
779                        value = True
780                    elif value == "False":
781                        value = False
782                    elif value == "None":
783                        value = None
784                if value is None:
785                    if sign == "=":
786                        data = data[data[group_name][par_name].isna()]
787                    elif sign == "!=":
788                        data = data[~data[group_name][par_name].isna()]
789                elif sign == ">":
790                    data = data[data[group_name][par_name] > value]
791                elif sign == ">=":
792                    data = data[data[group_name][par_name] >= value]
793                elif sign == "<":
794                    data = data[data[group_name][par_name] < value]
795                elif sign == "<=":
796                    data = data[data[group_name][par_name] <= value]
797                elif sign == "=":
798                    data = data[data[group_name][par_name] == value]
799                elif sign == "!=":
800                    data = data[data[group_name][par_name] != value]
801                else:
802                    raise ValueError(
803                        "Please use one of the signs: [>, <, >=, <=, =, !=]"
804                    )
805        except ValueError:
806            raise ValueError(
807                f"The {value_filter} filter is not valid, please use the following format:"
808                f" 'group1/parameter1::[sign][value],group2/parameter2::[sign][value]', "
809                f"e.g. 'training/num_epochs::>=200,model/num_f_maps::=128,meta/time::>2022-06-01'"
810            )
811        if display_parameters is not None:
812            if type(display_parameters[0]) is str:
813                display_parameters = [
814                    (x.split("/")[0], x.split("/")[1]) for x in display_parameters
815                ]
816            display_parameters = [x for x in display_parameters if x in data.columns]
817            data = data[display_parameters]
818        return data

Get a filtered pandas dataframe with episode metadata

Parameters

episode_names : List a list of strings of episode names value_filter : str a string of filters to apply of this general structure: 'group_name1/par_name1::(<>=)value1,group_name2/par_name2::(<>=)value2', e.g. 'data/overlap::=50,results/recall::>0.5,data/feature_extraction::=kinematic' display_parameters : List list of parameters to display (e.g. ['data/overlap', 'results/recall'])

Returns

pandas.DataFrame the filtered dataframe

def rename_episode(self, episode_name, new_episode_name)
820    def rename_episode(self, episode_name, new_episode_name):
821        if episode_name in self.data.index and new_episode_name not in self.data.index:
822            self.data.loc[new_episode_name] = self.data.loc[episode_name]
823            model_path = self.data.loc[new_episode_name, ("training", "model_path")]
824            self.data.loc[new_episode_name, ("training", "model_path")] = os.path.join(
825                os.path.dirname(model_path), new_episode_name
826            )
827            log_path = self.data.loc[new_episode_name, ("training", "log_file")]
828            self.data.loc[new_episode_name, ("training", "log_file")] = os.path.join(
829                os.path.dirname(log_path), f"{new_episode_name}.txt"
830            )
831            self.data = self.data.drop(index=episode_name)
832            self._save()
833        else:
834            raise ValueError("The names are wrong")
def remove_episode(self, episode_name: str) -> None:
836    def remove_episode(self, episode_name: str) -> None:
837        """
838        Remove all model, logs and metafile records related to an episode
839
840        Parameters
841        ----------
842        episode_name : str
843            the name of the episode to remove
844        """
845
846        if episode_name in self.data.index:
847            self.data = self.data.drop(index=episode_name)
848            self._save()

Remove all model, logs and metafile records related to an episode

Parameters

episode_name : str the name of the episode to remove

def unfinished_episodes(self) -> List:
850    def unfinished_episodes(self) -> List:
851        """
852        Get a list of unfinished episodes (currently running or interrupted)
853
854        Returns
855        -------
856        interrupted_episodes: List
857            a list of string names of unfinished episodes in the records
858        """
859
860        unfinished = []
861        for name, params in self.data.iterrows():
862            if Run(name, project_path=self.project_path, params=params).unfinished():
863                unfinished.append(name)
864        return unfinished

Get a list of unfinished episodes (currently running or interrupted)

Returns

interrupted_episodes: List a list of string names of unfinished episodes in the records

def update_episode_results(self, episode_name: str, logs: Tuple, training_time: str = None) -> None:
866    def update_episode_results(
867        self,
868        episode_name: str,
869        logs: Tuple,
870        training_time: str = None,
871    ) -> None:
872        """
873        Add results to an episode record
874
875        Parameters
876        ----------
877        episode_name : str
878            the name of the episode to update
879        logs : dict
880            a log dictionary from task.train()
881        training_time : str
882            the training time
883        """
884
885        metrics_log = logs[1]
886        results = {}
887        for key, value in metrics_log["val"].items():
888            results[("results", key)] = value[-1]
889        if training_time is not None:
890            results[("meta", "training_time")] = training_time
891        for k, v in results.items():
892            self.data.loc[episode_name, k] = v
893        self._save()

Add results to an episode record

Parameters

episode_name : str the name of the episode to update logs : dict a log dictionary from task.train() training_time : str the training time

def get_runs(self, episode_name: str) -> List:
895    def get_runs(self, episode_name: str) -> List:
896        """
897        Get a list of runs with this episode name (episodes like `episode_name::0`)
898
899        Parameters
900        ----------
901        episode_name : str
902            the name of the episode
903
904        Returns
905        -------
906        runs_list : List
907            a list of string run names
908        """
909
910        if episode_name is None:
911            return []
912        index = self.data.index
913        runs_list = []
914        for name in index:
915            if name.startswith(episode_name):
916                split = name.split("::")
917                if split[0] == episode_name:
918                    if len(split) > 1 and split[-1].isnumeric() or len(split) == 1:
919                        runs_list.append(name)
920                elif name == episode_name:
921                    runs_list.append(name)
922        return runs_list

Get a list of runs with this episode name (episodes like `episode_name::0`)

Parameters
----------
episode_name : str
    the name of the episode

Returns
-------
runs_list : List
    a list of string run names
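
The matching rules are easy to misread, so here is a worked illustration of which index entries qualify (the names are invented):

# With these records in the index ...
#   ["test", "test::0", "test::1", "test::best", "test_2"]
# ... get_runs("test") returns ["test", "test::0", "test::1"]:
#   - "test"       passes because split == ["test"] (len(split) == 1)
#   - "test::0"    and "test::1" pass: split[0] == "test" and split[-1].isnumeric()
#   - "test::best" fails the isnumeric() check
#   - "test_2"     starts with "test" but split[0] != "test"
runs = saved_runs.get_runs("test")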

class Searches(SavedRuns):
 932class Searches(SavedRuns):
 933    """
 934    A class that manages operations with search records
 935    """
 936
 937    def save_search(
 938        self,
 939        search_name: str,
 940        parameters: Dict,
 941        n_trials: int,
 942        best_params: Dict,
 943        best_value: float,
 944        metric: str,
 945        search_space: Dict,
 946    ) -> None:
 947        """
 948        Save a new search record
 949
 950        Parameters
 951        ----------
 952        search_name : str
 953            the name of the search to save
 954        parameters : dict
 955            the task parameters to save
 956        n_trials : int
 957            the number of trials in the search
 958        best_params : dict
 959            the best parameters dictionary
 960        best_value : float
 961            the best value
 962        metric : str
 963            the name of the objective metric
 964        search_space : dict
 965            a dictionary representing the search space; of this general structure:
 966            {'group/param_name': ('float/int/float_log/int_log', start, end),
 967            'group/param_name': ('categorical', [choices])}, e.g.
 968            {'data/overlap': ('int', 5, 100), 'training/lr': ('float_log', 1e-4, 1e-2),
 969            'data/feature_extraction': ('categorical', ['kinematic', 'bones'])}
 970        """
 971
 972        pars = deepcopy(parameters)
 973        pars["results"] = {"best_value": best_value, "best_params": best_params}
 974        pars["meta"] = {
 975            "objective": metric,
 976            "n_trials": n_trials,
 977            "search_space": search_space,
 978        }
 979        self.save_episode(search_name, pars, {})
 980
 981    def get_best_params_raw(self, search_name: str) -> Dict:
 982        """
 983        Get the raw dictionary of best parameters found by a search
 984
 985        Parameters
 986        ----------
 987        search_name : str
 988            the name of the search
 989
 990        Returns
 991        -------
 992        best_params : dict
 993            a dictionary of the best parameters where the keys are in '{group}/{name}' format
 994        """
 995
 996        return self.data.loc[search_name]["results"]["best_params"]
 997
 998    def get_best_params(
 999        self,
1000        search_name: str,
1001        load_parameters: List = None,
1002        round_to_binary: List = None,
1003    ) -> Dict:
1004        """
1005        Get the best parameters from a search
1006
1007        Parameters
1008        ----------
1009        search_name : str
1010            the name of the search
1011        load_parameters : List, optional
1012            a list of string names of the parameters to load (if not provided all parameters are loaded)
1013        round_to_binary : List, optional
1014            a list of string names of the loaded parameters that should be rounded to the nearest power of two
1015
1016        Returns
1017        -------
1018        best_params : dict
1019            a dictionary of the best parameters
1020        """
1021
1022        if round_to_binary is None:
1023            round_to_binary = []
1024        params = self.data.loc[search_name]["results"]["best_params"]
1025        if load_parameters is not None:
1026            params = {k: v for k, v in params.items() if k in load_parameters}
1027        for par_name in round_to_binary:
1028            if par_name not in params:
1029                continue
1030            if not isinstance(params[par_name], float) and not isinstance(
1031                params[par_name], int
1032            ):
1033                raise TypeError(
1034                    f"Cannot round {par_name} parameter of type {type(params[par_name])} to a power of two"
1035                )
1036            i = 1
1037            while 2**i < params[par_name]:
1038                i += 1
1039            if params[par_name] - (2 ** (i - 1)) < (2**i) - params[par_name]:
1040                params[par_name] = 2 ** (i - 1)
1041            else:
1042                params[par_name] = 2**i
1043        res = defaultdict(lambda: defaultdict(lambda: {}))
1044        for k, v in params.items():
1045            big_key, small_key = k.split("/")[0], "/".join(k.split("/")[1:])
1046            if len(small_key.split("/")) == 1:
1047                res[big_key][small_key] = v
1048            else:
1049                group, key = small_key.split("/")
1050                res[big_key][group][key] = v
1051        model = self.data.loc[search_name]["general"]["model_name"]
1052        return res, model

A class that manages operations with search records
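
A hedged sketch of recording a finished hyperparameter search with `save_search` (shown in the listing above); the search name, parameter dictionary, and result values are placeholders, while the `search_space` keys follow the format from the docstring:

searches.save_search(
    "tcn_search",                    # made-up search name
    parameters=task_parameters,      # hypothetical dictionary of task parameters
    n_trials=50,
    best_params={"data/overlap": 60, "training/lr": 3.1e-4},
    best_value=0.72,                 # made-up objective value
    metric="f1",
    search_space={
        "data/overlap": ("int", 5, 100),
        "training/lr": ("float_log", 1e-4, 1e-2),
        "data/feature_extraction": ("categorical", ["kinematic", "bones"]),
    },
)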

def get_best_params_raw(self, search_name: str) -> Dict:

Get the raw dictionary of best parameters found by a search

Parameters
----------
search_name : str
    the name of the search

Returns
-------
best_params : dict
    a dictionary of the best parameters where the keys are in '{group}/{name}' format
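
Continuing the made-up search from above, the raw getter returns the flat '{group}/{name}' mapping exactly as saved:

raw = searches.get_best_params_raw("tcn_search")
# e.g. {"data/overlap": 60, "training/lr": 0.00031}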

def get_best_params(self, search_name: str, load_parameters: List = None, round_to_binary: List = None) -> Dict:

Get the best parameters from a search

Parameters
----------
search_name : str
    the name of the search
load_parameters : List, optional
    a list of string names of the parameters to load (if not provided all parameters are loaded)
round_to_binary : List, optional
    a list of string names of the loaded parameters that should be rounded to the nearest power of two

Returns
-------
best_params : dict
    a nested dictionary of the best parameters (grouped as {group: {name: value}})
model_name : str
    the name of the model (the method actually returns a `(best_params, model_name)` tuple, despite the `Dict` annotation)
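
A sketch of the re-nesting and the power-of-two rounding, using the made-up search from above; note the tuple unpacking, and that the model name assumes the saved parameters contain a `general/model_name` entry (as the source reads `["general"]["model_name"]`):

params, model = searches.get_best_params(
    "tcn_search",
    round_to_binary=["data/overlap"],
)
# Rounding walks powers of two until 2**i >= value; for overlap == 60 that is
# i == 6, and since |60 - 32| = 28 is not less than |64 - 60| = 4, it rounds up to 64.
# Flat '{group}/{name}' keys are then re-nested by group:
# params == {"data": {"overlap": 64}, "training": {"lr": 0.00031}}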

class Suggestions(SavedRuns):
1055class Suggestions(SavedRuns):
1056    def save_suggestion(self, episode_name: str, parameters: Dict, meta_parameters):
1057        pars = deepcopy(parameters)
1058        pars["meta"] = meta_parameters
1059        super().save_episode(episode_name, pars, behaviors_dict=None)

A class that manages operations with suggestion records

def save_suggestion(self, episode_name: str, parameters: Dict, meta_parameters):

Save a suggestion record: a deep copy of `parameters` is stored with `meta_parameters` under the "meta" key via `SavedRuns.save_episode` (see the listing above).

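
A hedged usage sketch; `suggestions` is a hypothetical instance of this class, and the argument contents are placeholders:

suggestions.save_suggestion(
    "suggestion_1",                             # made-up suggestion name
    parameters=episode_parameters,              # hypothetical parameter dictionary
    meta_parameters={"created": "2022-05-01"},  # stored under the "meta" key
)
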
class SavedStores:
1062class SavedStores:
1063    """
1064    A class that manages operations with saved dataset records
1065    """
1066
1067    def __init__(self, path):
1068        """
1069        Parameters
1070        ----------
1071        path : str
1072            the path to the pickled dataframe of dataset records
1073        """
1074
1075        self.path = path
1076        self.data = pd.read_pickle(path)
1077        self.skip_keys = [
1078            "feature_save_path",
1079            "saved_data_path",
1080            "real_lens",
1081            "recompute_annotation",
1082        ]
1083
1084    def clear(self) -> None:
1085        """
1086        Remove all datasets
1087        """
1088
1089        for dataset_name in self.data.index:
1090            self.remove_dataset(dataset_name)
1091
1092    def dataset_names(self) -> List:
1093        """
1094        Get a list of dataset names
1095
1096        Returns
1097        -------
1098        dataset_names : List
1099            a list of string dataset names
1100        """
1101
1102        return list(self.data.index)
1103
1104    def remove(self, names: List) -> None:
1105        """
1106        Remove some datasets
1107
1108        Parameters
1109        ----------
1110        names : List
1111            a list of string names of the datasets to delete
1112        """
1113
1114        for dataset_name in names:
1115            if dataset_name in self.data.index:
1116                self.remove_dataset(dataset_name)
1117
1118    def remove_dataset(self, dataset_name: str) -> None:
1119        """
1120        Remove a dataset record
1121
1122        Parameters
1123        ----------
1124        dataset_name : str
1125            the name of the dataset to remove
1126        """
1127
1128        if dataset_name in self.data.index:
1129            self.data = self.data.drop(index=dataset_name)
1130            self._save()
1131
1132    def find_name(self, parameters: Dict) -> str:
1133        """
1134        Find a record that satisfies the parameters (if it exists)
1135
1136        Parameters
1137        ----------
1138        parameters : dict
1139            a dictionary of data parameters
1140
1141        Returns
1142        -------
1143        name : str
1144            the name of a record that has the same parameters (None if it does not exist; the earliest if there are
1145            several)
1146        """
1147
1148        filter = deepcopy(parameters)
1149        for key, value in parameters.items():
1150            if value is None or key in self.skip_keys:
1151                filter.pop(key)
1152            elif key not in self.data.columns:
1153                return None
1154        saved_annotation = self.data[
1155            (self.data[list(filter)] == pd.Series(filter)).all(axis=1)
1156        ]
1157        for i in range(len(saved_annotation)):
1158            ok = True
1159            for key in saved_annotation.columns:
1160                if key in self.skip_keys:
1161                    continue
1162                isnull = pd.isnull(saved_annotation.iloc[i][key])
1163                if not isinstance(isnull, bool):
1164                    isnull = False
1165                if key not in filter and not isnull:
1166                    ok = False
1167            if ok:
1168                name = saved_annotation.iloc[i].name
1169                return name
1170        return None
1171
1172    def save_store(self, episode_name: str, parameters: Dict) -> None:
1173        """
1174        Save a new saved dataset record
1175
1176        Parameters
1177        ----------
1178        episode_name : str
1179            the name of the dataset
1180        parameters : dict
1181            a dictionary of data parameters
1182        """
1183
1184        pars = deepcopy(parameters)
1185        for k, v in parameters.items():
1186            if k not in self.data.columns:
1187                self.data[k] = np.nan
1188        if self.find_name(pars) is None:
1189            self.data.loc[episode_name] = pars
1190        self._save()
1191
1192    def _save(self):
1193        """
1194        Save the dataframe
1195        """
1196
1197        self.data.to_pickle(self.path)
1198
1199    def check_name_validity(self, store_name: str) -> bool:
1200        """
1201        Check whether a store name is available (i.e. does not already exist)
1202
1203        Parameters
1204        ----------
1205        store_name : str
1206            the name to check
1207
1208        Returns
1209        -------
1210        result : bool
1211            True if the name can be used
1212        """
1213
1214        if store_name in self.data.index:
1215            return False
1216        else:
1217            return True

A class that manages operations with saved dataset records

SavedStores(path)

Parameters
----------
path : str
    the path to the pickled dataframe of dataset records

def clear(self) -> None:

Remove all datasets

def dataset_names(self) -> List:

Get a list of dataset names

Returns
-------
dataset_names : List
    a list of string dataset names
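
A small sketch tying the listing and removal methods together; the pickle path and dataset names are invented:

stores = SavedStores("/project/meta/saved_stores.pickle")  # hypothetical path
print(stores.dataset_names())                    # e.g. ["dataset_1", "dataset_2"]
stores.remove(["dataset_1", "no_such_dataset"])  # unknown names are silently skipped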

def remove(self, names: List) -> None:

Remove some datasets

Parameters
----------
names : List
    a list of string names of the datasets to delete

def remove_dataset(self, dataset_name: str) -> None:

Remove a dataset record

Parameters
----------
dataset_name : str
    the name of the dataset to remove

def find_name(self, parameters: Dict) -> str:

Find a record that satisfies the parameters (if it exists)

Parameters
----------
parameters : dict
    a dictionary of data parameters

Returns
-------
name : str
    the name of a record that has the same parameters (None if it does not exist; the earliest if there are several)
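
A sketch of the matching semantics, with invented parameter names: `None`-valued keys and `skip_keys` entries are dropped from the filter, and a candidate row only matches if every one of its remaining non-null columns was part of the query:

name = stores.find_name({
    "data_type": "dlc_track",  # hypothetical data parameter
    "len_segment": 128,        # hypothetical data parameter
    "overlap": None,           # dropped: None values are filtered out
})
if name is None:
    print("no saved dataset with these parameters")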

def save_store(self, episode_name: str, parameters: Dict) -> None:

Save a new saved dataset record

Parameters
----------
episode_name : str
    the name of the dataset
parameters : dict
    a dictionary of data parameters
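
Because `save_store` consults `find_name` first, saving the same parameters twice keeps only the first record; a sketch with the invented parameters from above:

stores.save_store("dataset_3", {"data_type": "dlc_track", "len_segment": 128})
# a second call with identical parameters does not add a new row:
stores.save_store("dataset_3_copy", {"data_type": "dlc_track", "len_segment": 128})
# find_name() already resolves to "dataset_3", so "dataset_3_copy" is skipped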

def check_name_validity(self, store_name: str) -> bool:

Check whether a store name is available (i.e. does not already exist)

Parameters
----------
store_name : str
    the name to check

Returns
-------
result : bool
    True if the name can be used
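
A typical guard before creating a new dataset record (the name is invented, and `parameters` is the hypothetical dictionary from the sketches above):

new_name = "dataset_4"
if not stores.check_name_validity(new_name):
    raise ValueError(f"store name {new_name!r} is already taken")
stores.save_store(new_name, parameters)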