dlc2action.project.meta

Handling meta (history) files.

   1#
   2# Copyright 2020-present by A. Mathis Group and contributors. All rights reserved.
   3#
   4# This project and all its files are licensed under GNU AGPLv3 or later version. 
   5# A copy is included in dlc2action/LICENSE.AGPL.
   6#
   7"""Handling meta (history) files."""
   8
   9import ast
  10import os
  11import warnings
  12from collections import defaultdict
  13from copy import deepcopy
  14from time import localtime, strftime
  15from typing import Dict, List, Set, Tuple, Union
  16
  17import numpy as np
  18import pandas as pd
  19from dlc2action.utils import correct_path
  20from abc import abstractmethod
  21
  22import re
  23
  24class Run:
  25    """A class that manages operations with a single episode record."""
  26
  27    def __init__(
  28        self,
  29        episode_name: str,
  30        project_path: str,
  31        meta_path: str = None,
  32        params: Dict = None,
  33    ):
  34        """Initialize the class.
  35
  36        Parameters
  37        ----------
  38        episode_name : str
  39            the name of the episode
  40        project_path : str
  41            the path to the project folder
  42        meta_path : str, optional
  43            the path to the pickled SavedRuns dataframe
  44        params : dict, optional
  45            alternative to meta_path: pre-loaded pandas Series of episode parameters
  46
  47        """
  48        self.name = episode_name
  49        self.project_path = project_path
  50        if meta_path is not None:
  51            try:
  52                self.params = pd.read_pickle(meta_path).loc[episode_name]
  53            except:
  54                raise ValueError(f"The {episode_name} episode does not exist!")
  55        elif params is not None:
  56            self.params = params
  57        else:
  58            raise ValueError("Either meta_path or params has to be not None")
  59        self.params = self._check_str_conversion()
  60
  61    def _check_str_conversion(self):
  62        """Check if the parameters are in string format and convert them to the correct type."""
  63        return _check_str_conversion(self.params)
  64
  65    def training_time(self) -> int:
  66        """Get the training time in seconds.
  67
  68        Returns
  69        -------
  70        training_time : int
  71            the training time in seconds
  72
  73        """
  74        time_str = self.params["meta"].get("training_time")
  75        try:
  76            if time_str is None or np.isnan(time_str):
  77                return np.nan
  78        except TypeError:
  79            pass
  80        h, m, s = time_str.split(":")
  81        seconds = int(h) * 3600 + int(m) * 60 + int(s)
  82        return seconds
  83
  84    def model_file(self, load_epoch: int = None) -> str:
  85        """Get a checkpoint file path.
  86
  87        Parameters
  88        ----------
  89        load_epoch : int, optional
  90            the epoch to load (the closest checkpoint will be chosen; if not given will be set to last)
  91
  92        Returns
  93        -------
  94        checkpoint_path : str
  95            the path to the checkpoint
  96
  97        """
  98        model_path = correct_path(
  99            self.params["training"]["model_save_path"], self.project_path
 100        )
 101        if load_epoch is None:
 102            model_file = sorted(os.listdir(model_path))[-1]
 103        else:
 104            model_files = os.listdir(model_path)
 105            if len(model_files) == 0:
 106                model_file = None
 107            else:
 108                epochs = [int(file[5:].split(".")[0]) for file in model_files]
 109                diffs = [np.abs(epoch - load_epoch) for epoch in epochs]
 110                argmin = np.argmin(diffs)
 111                model_file = model_files[argmin]
 112        model_file = os.path.join(model_path, model_file)
 113        return model_file
 114
 115    def dataset_name(self) -> str:
 116        """Get the dataset name.
 117
 118        Returns
 119        -------
 120        dataset_name : str
 121            the name of the dataset record
 122
 123        """
 124        data_path = correct_path(
 125            self.params["data"]["feature_save_path"], self.project_path
 126        )
 127        dataset_name = os.path.basename(data_path)
 128        return dataset_name
 129
 130    def split_file(self) -> str:
 131        """Get the split file.
 132
 133        Returns
 134        -------
 135        split_path : str
 136            the path to the split file
 137
 138        """
 139        return correct_path(self.params["training"]["split_path"], self.project_path)
 140
 141    def log_file(self) -> str:
 142        """Get the log file.
 143
 144        Returns
 145        -------
 146        log_path : str
 147            the path to the log file
 148
 149        """
 150        return correct_path(self.params["training"]["log_file"], self.project_path)
 151
 152    def split_info(self) -> Dict:
 153        """Get the train/test/val split information.
 154
 155        Returns
 156        -------
 157        split_info : dict
 158            a dictionary with [val_frac, test_frac, partition_method] keys and corresponding values
 159
 160        """
 161        val_frac = self.params["training"]["val_frac"]
 162        test_frac = self.params["training"]["test_frac"]
 163        partition_method = self.params["training"]["partition_method"]
 164        return {
 165            "val_frac": val_frac,
 166            "test_frac": test_frac,
 167            "partition_method": partition_method,
 168        }
 169
 170    def same_split_info(self, split_info: Dict) -> bool:
 171        """Check whether this episode has the same split information.
 172
 173        Parameters
 174        ----------
 175        split_info : dict
 176            a dictionary with [val_frac, test_frac, partition_method] keys and corresponding values from another episode
 177
 178        Returns
 179        -------
 180        result : bool
 181            if True, this episode has the same split information
 182
 183        """
 184        self_split_info = self.split_info()
 185        for k in ["val_frac", "test_frac", "partition_method"]:
 186            if self_split_info[k] != split_info[k]:
 187                return False
 188        return True
 189
 190    def get_metrics(self) -> List:
 191        """Get a list of tracked metrics.
 192
 193        Returns
 194        -------
 195        metrics : list
 196            a list of tracked metric names
 197
 198        """
 199        return self.params["general"]["metric_functions"]
 200
 201    def get_metric_log(self, mode: str, metric_name: str) -> np.ndarray:
 202        """Get the metric log.
 203
 204        Parameters
 205        ----------
 206        mode : {'train', 'val'}
 207            the mode to get the log from
 208        metric_name : str
 209            the metric to get the log for (has to be one of the metric computed for this episode during training)
 210
 211        Returns
 212        -------
 213        log : np.ndarray
 214            the log of metric values (empty if the metric was not computed during training)
 215
 216        """
 217        metric_array = []
 218        with open(self.log_file()) as f:
 219            for line in f.readlines():
 220                if mode == "train" and line.startswith("[epoch"):
 221                    line = line.split("]: ")[1]
 222                elif mode == "val" and line.startswith("validation"):
 223                    line = line.split("validation: ")[1]
 224                else:
 225                    continue
 226                metrics = line.split(", ")
 227
 228                metric_ind = np.where(
 229                    np.array([m.split()[0] for m in metrics]) == metric_name
 230                )[0]
 231                if len(metric_ind):
 232                    name, value = metrics[metric_ind[0]].split()
 233                    metric_array.append(float(value))
 234                else:
 235                    metric_inds = [
 236                        m for m in metrics if m.split()[0].split("_")[0] == metric_name
 237                    ]
 238                    if len(metric_inds):
 239                        beh_metrics_avg = np.mean(
 240                            [float(m.split()[1]) for m in metric_inds]
 241                        )
 242                        metric_array.append(beh_metrics_avg)
 243
 244        return np.array(metric_array)
 245
 246    def get_epoch_list(self, mode) -> List:
 247        """Get a list of epoch indices.
 248
 249        Parameters
 250        ----------
 251        mode : {'train', 'val'}
 252            the mode to get the epoch list for
 253
 254        Returns
 255        -------
 256        epoch_list : list
 257            a list of int epoch indices
 258
 259        """
 260        epoch_list = []
 261        with open(self.log_file()) as f:
 262            for line in f.readlines():
 263                if line.startswith("[epoch"):
 264                    epoch = int(line[7:].split("]:")[0])
 265                    if mode == "train":
 266                        epoch_list.append(epoch)
 267                elif mode == "val":
 268                    epoch_list.append(epoch)
 269        return epoch_list
 270
 271    def get_metrics(self) -> List:
 272        """Get a list of metric names in the episode log.
 273
 274        Returns
 275        -------
 276        metrics : List
 277            a list of string metric names
 278
 279        """
 280        metrics = []
 281        with open(self.log_file()) as f:
 282            for line in f.readlines():
 283                if line.startswith("[epoch"):
 284                    line = line.split("]: ")[1]
 285                elif line.startswith("validation"):
 286                    line = line.split("validation: ")[1]
 287                else:
 288                    continue
 289                metric_logs = line.split(", ")
 290                for metric in metric_logs:
 291                    name, _ = metric.split()
 292                    metrics.append(name)
 293                break
 294        return metrics
 295
 296    def unfinished(self) -> bool:
 297        """Check whether this episode was interrupted.
 298
 299        Returns
 300        -------
 301        result : bool
 302            True if the number of epochs in the log file is smaller than in the parameters
 303
 304        """
 305        num_epoch_theor = self.params["training"]["num_epochs"]
 306        log_file = self.log_file()
 307        if not isinstance(log_file, str):
 308            return False
 309        if not os.path.exists(log_file):
 310            return True
 311        with open(self.log_file()) as f:
 312            num_epoch = 0
 313            val = False
 314            for line in f.readlines():
 315                num_epoch += 1
 316                if num_epoch == 2 and line.startswith("validation"):
 317                    val = True
 318            if val:
 319                num_epoch //= 2
 320        return num_epoch < num_epoch_theor
 321
 322    def get_class_ind(self, class_name: str) -> int:
 323        """Get the integer label from a class name.
 324
 325        Parameters
 326        ----------
 327        class_name : str
 328            the name of the class
 329
 330        Returns
 331        -------
 332        class_ind : int
 333            the integer label
 334
 335        """
 336        behaviors_dict = self.params["meta"]["behaviors_dict"]
 337        for k, v in behaviors_dict.items():
 338            if v == class_name:
 339                return k
 340        raise ValueError(
 341            f"The {class_name} class is not in classes predicted by {self.name} ({behaviors_dict})"
 342        )
 343
 344    def get_behaviors_dict(self) -> Dict:
 345        """Get behaviors dictionary in the episode.
 346
 347        Returns
 348        -------
 349        behaviors_dict : dict
 350            a dictionary with class indices as keys and labels as values
 351
 352        """
 353        behavior_dict = self.params["meta"]["behaviors_dict"]
 354        if isinstance(behavior_dict, str):
 355            behavior_dict = ast.literal_eval(behavior_dict)
 356
 357        return behavior_dict
 358
 359    def get_num_classes(self) -> int:
 360        """Get number of classes in episode.
 361
 362        Returns
 363        -------
 364        num_classes : int
 365            the number of classes
 366
 367        """
 368        return len(self.params["meta"]["behaviors_dict"])
 369
 370
 371class DecisionThresholds:
 372    """A class that saves and looks up tuned decision thresholds."""
 373
 374    def __init__(self, path: str) -> None:
 375        """Initialize the class.
 376
 377        Parameters
 378        ----------
 379        path : str
 380            the path to the pickled SavedRuns dataframe
 381
 382        """
 383        self.path = path
 384        self.data = pd.read_pickle(path)
 385
 386    def save_thresholds(
 387        self,
 388        episode_names: List,
 389        epochs: List,
 390        metric_name: str,
 391        metric_parameters: Dict,
 392        thresholds: List,
 393    ) -> None:
 394        """Add a new record.
 395
 396        Parameters
 397        ----------
 398        episode_names : list
 399            the names of the episodes
 400        epochs : int
 401            the epoch index list
 402        metric_name : str
 403            the name of the metric the thresholds were tuned on
 404        metric_parameters : dict
 405            the metric parameter dictionary
 406        thresholds : list
 407            a list of float decision thresholds
 408
 409        """
 410        episodes = set(zip(episode_names, epochs))
 411        for key in ["average", "threshold_value", "ignored_classes"]:
 412            if key in metric_parameters:
 413                metric_parameters.pop(key)
 414        parameters = {(metric_name, k): v for k, v in metric_parameters.items()}
 415        parameters["thresholds"] = thresholds
 416        parameters["episodes"] = episodes
 417        pars = {k: [v] for k, v in parameters.items()}
 418        self.data = pd.concat([self.data, pd.DataFrame.from_dict(pars)], axis=0)
 419        self._save()
 420
 421    def find_thresholds(
 422        self,
 423        episode_names: List,
 424        epochs: List,
 425        metric_name: str,
 426        metric_parameters: Dict,
 427    ) -> Union[List, None]:
 428        """Find a record.
 429
 430        Parameters
 431        ----------
 432        episode_names : list
 433            the names of the episodes
 434        epochs : list
 435            the epoch index list
 436        metric_name : str
 437            the name of the metric the thresholds were tuned on
 438        metric_parameters : dict
 439            the metric parameter dictionary
 440
 441        Returns
 442        -------
 443        thresholds : list
 444            a list of float decision thresholds
 445
 446        """
 447        episodes = set(zip(episode_names, epochs))
 448        for key in ["average", "threshold_value", "ignored_classes"]:
 449            if key in metric_parameters:
 450                metric_parameters.pop(key)
 451        parameters = {(metric_name, k): v for k, v in metric_parameters.items()}
 452        parameters["episodes"] = episodes
 453        filter = deepcopy(parameters)
 454        for key, value in parameters.items():
 455            if value is None:
 456                filter.pop(key)
 457            elif key not in self.data.columns:
 458                return None
 459        data = self.data[(self.data[list(filter)] == pd.Series(filter)).all(axis=1)]
 460        if len(data) > 0:
 461            thresholds = data.iloc[0]["thresholds"]
 462            return thresholds
 463        else:
 464            return None
 465
 466    def _save(self) -> None:
 467        """Save the records."""
 468        self.data.copy().to_pickle(self.path)
 469
 470
 471class SavedRuns:
 472    """A class that manages operations with all episode (or prediction) records."""
 473
 474    def __init__(self, path: str, project_path: str) -> None:
 475        """Initialize the class.
 476
 477        Parameters
 478        ----------
 479        path : str
 480            the path to the pickled SavedRuns dataframe
 481        project_path : str
 482            the path to the project folder
 483
 484        """
 485        self.path = path
 486        self.project_path = project_path
 487        self.data = pd.read_pickle(path)
 488        self.data = _check_str_conversion(self.data)
 489
 490    def update(
 491        self,
 492        data: pd.DataFrame,
 493        data_path: str,
 494        annotation_path: str,
 495        name_map: Dict = None,
 496        force: bool = False,
 497    ) -> None:
 498        """Update with new data.
 499
 500        Parameters
 501        ----------
 502        data : pd.DataFrame
 503            the new dataframe
 504        data_path : str
 505            the new data path
 506        annotation_path : str
 507            the new annotation path
 508        name_map : dict, optional
 509            the name change dictionary; keys are old episode names and values are new episode names
 510        force : bool, default False
 511            replace existing episodes if `True`
 512
 513        """
 514        if name_map is None:
 515            name_map = {}
 516        data = data.rename(index=name_map)
 517        for episode in data.index:
 518            new_model = os.path.join(self.project_path, "results", "model", episode)
 519            data.loc[episode, ("training", "model_save_path")] = new_model
 520            new_log = os.path.join(
 521                self.project_path, "results", "logs", f"{episode}.txt"
 522            )
 523            data.loc[episode, ("training", "log_file")] = new_log
 524            old_split = data.loc[episode, ("training", "split_path")]
 525            if old_split is None:
 526                new_split = None
 527            else:
 528                new_split = os.path.join(
 529                    self.project_path, "results", "splits", os.path.basename(old_split)
 530                )
 531            data.loc[episode, ("training", "split_path")] = new_split
 532            data.loc[episode, ("data", "data_path")] = data_path
 533            data.loc[episode, ("data", "annotation_path")] = annotation_path
 534            if episode in self.data.index:
 535                if force:
 536                    self.data = self.data.drop(index=[episode])
 537                else:
 538                    raise RuntimeError(f"The {episode} episode name is already taken!")
 539        self.data = pd.concat([self.data, data])
 540        self._save()
 541
 542    def get_subset(self, episode_names: List) -> pd.DataFrame:
 543        """Get a subset of the raw metadata.
 544
 545        Parameters
 546        ----------
 547        episode_names : list
 548            a list of the episodes to include
 549
 550        Returns
 551        -------
 552        subset : pd.DataFrame
 553            the subset of the raw metadata
 554
 555        """
 556        for episode in episode_names:
 557            if episode not in self.data.index:
 558                raise ValueError(
 559                    f"The {episode} episode is not in the records; please run `Project.list_episodes()` to explore the records"
 560                )
 561        return self.data.loc[episode_names]
 562
 563    def get_saved_data_path(self, episode_name: str) -> str:
 564        """Get the `saved_data_path` parameter for the episode.
 565
 566        Parameters
 567        ----------
 568        episode_name : str
 569            the name of the episode
 570
 571        Returns
 572        -------
 573        saved_data_path : str
 574            the saved data path
 575
 576        """
 577        return self.data.loc[episode_name]["data"]["saved_data_path"]
 578
 579    def check_name_validity(self, episode_name: str) -> bool:
 580        """Check if an episode name already exists.
 581
 582        Parameters
 583        ----------
 584        episode_name : str
 585            the name to check
 586
 587        Returns
 588        -------
 589        result : bool
 590            True if the name can be used
 591
 592        """
 593        if episode_name in self.data.index:
 594            return False
 595        else:
 596            return True
 597
 598    def update_episode_metrics(self, episode_name: str, metrics: Dict) -> None:
 599        """Update meta data with evaluation results.
 600
 601        Parameters
 602        ----------
 603        episode_name : str
 604            the name of the episode to update
 605        metrics : dict
 606            a dictionary of the metrics
 607
 608        """
 609        for key, value in metrics.items():
 610            self.data.loc[episode_name, ("results", key)] = value
 611        self._save()
 612
 613    def save_episode(
 614        self,
 615        episode_name: str,
 616        parameters: Dict,
 617        behaviors_dict: Dict,
 618        suppress_validation: bool = False,
 619        training_time: str = None,
 620    ) -> None:
 621        """Save a new run record.
 622
 623        Parameters
 624        ----------
 625        episode_name : str
 626            the name of the episode
 627        parameters : dict
 628            the parameters to save
 629        behaviors_dict : dict
 630            the dictionary of behaviors (keys are indices, values are names)
 631        suppress_validation : bool, optional False
 632            if True, existing episode with the same name will be overwritten
 633        training_time : str, optional
 634            the training time in '%H:%M:%S' format
 635
 636        """
 637        if not suppress_validation and episode_name in self.data.index:
 638            raise ValueError(f"Episode {episode_name} already exists!")
 639        pars = deepcopy(parameters)
 640        if "meta" not in pars:
 641            pars["meta"] = {
 642                "time": strftime("%Y-%m-%d %H:%M:%S", localtime()),
 643                "behaviors_dict": behaviors_dict,
 644            }
 645        else:
 646            pars["meta"]["time"] = strftime("%Y-%m-%d %H:%M:%S", localtime())
 647            pars["meta"]["behaviors_dict"] = behaviors_dict
 648        if training_time is not None:
 649            pars["meta"]["training_time"] = training_time
 650        if len(parameters.keys()) > 1:
 651            pars["losses"] = pars["losses"].get(pars["general"]["loss_function"], {})
 652            for metric_name in pars["general"]["metric_functions"]:
 653                pars[metric_name] = pars["metrics"].get(metric_name, {})
 654            if pars["general"].get("ssl", None) is not None:
 655                for ssl_name in pars["general"]["ssl"]:
 656                    pars[ssl_name] = pars["ssl"].get(ssl_name, {})
 657            for group_name in ["metrics", "ssl"]:
 658                if group_name in pars:
 659                    pars.pop(group_name)
 660        data = {
 661            (big_key, small_key): value
 662            for big_key, big_value in pars.items()
 663            for small_key, value in big_value.items()
 664        }
 665        list_keys = []
 666        with warnings.catch_warnings():
 667            warnings.filterwarnings("ignore", message="DataFrame is highly fragmented")
 668            for k, v in data.items():
 669                if k not in self.data.columns:
 670                    self.data[k] = np.nan
 671                if isinstance(v, list) and not isinstance(v, str):
 672                    list_keys.append(k)
 673            for k in list_keys:
 674                self.data[k] = self.data[k].astype(object)
 675            self.data.loc[episode_name] = data
 676        self._save()
 677
 678    def load_parameters(self, episode_name: str) -> Dict:
 679        """Load the task parameters from a record.
 680
 681        Parameters
 682        ----------
 683        episode_name : str
 684            the name of the episode to load
 685
 686        Returns
 687        -------
 688        parameters : dict
 689            the loaded task parameters
 690
 691        """
 692        parameters = defaultdict(lambda: defaultdict(lambda: {}))
 693        episode = self.data.loc[episode_name].dropna().to_dict()
 694        keys = ["data", "augmentations", "general", "training", "model", "features"]
 695        for key in episode:
 696            big_key, small_key = key
 697            if big_key in keys:
 698                parameters[big_key][small_key] = episode[key]
 699        # parameters = {k: dict(v) for k, v in parameters.items()}
 700        ssl_keys = parameters["general"].get("ssl", None)
 701        metric_keys = parameters["general"].get("metric_functions", None)
 702        loss_key = parameters["general"]["loss_function"]
 703        if ssl_keys is None:
 704            ssl_keys = []
 705        if metric_keys is None:
 706            metric_keys = []
 707        for key in episode:
 708            big_key, small_key = key
 709            if big_key in ssl_keys:
 710                parameters["ssl"][big_key][small_key] = episode[key]
 711            elif big_key in metric_keys:
 712                parameters["metrics"][big_key][small_key] = episode[key]
 713            elif big_key == "losses":
 714                parameters["losses"][loss_key][small_key] = episode[key]
 715        parameters = {k: dict(v) for k, v in parameters.items()}
 716        parameters["general"]["num_classes"] = Run(
 717            episode_name, self.project_path, params=self.data.loc[episode_name]
 718        ).get_num_classes()
 719        return parameters
 720
 721    def get_active_datasets(self) -> List:
 722        """Get a list of names of datasets that are used by unfinished episodes.
 723
 724        Returns
 725        -------
 726        active_datasets : list
 727            a list of dataset names used by unfinished episodes
 728
 729        """
 730        active_datasets = []
 731        for episode_name in self.unfinished_episodes():
 732            run = Run(
 733                episode_name, self.project_path, params=self.data.loc[episode_name]
 734            )
 735            active_datasets.append(run.dataset_name())
 736        return active_datasets
 737
 738    def list_episodes(
 739        self,
 740        episode_names: List = None,
 741        value_filter: str = "",
 742        display_parameters: List = None,
 743    ) -> pd.DataFrame:
 744        """Get a filtered pandas dataframe with episode metadata.
 745
 746        Parameters
 747        ----------
 748        episode_names : List
 749            a list of strings of episode names
 750        value_filter : str
 751            a string of filters to apply of this general structure:
 752            'group_name1/par_name1::(<>=)value1,group_name2/par_name2::(<>=)value2', e.g.
 753            'data/overlap::=50,results/recall::>0.5,data/feature_extraction::=kinematic'
 754        display_parameters : List
 755            list of parameters to display (e.g. ['data/overlap', 'results/recall'])
 756
 757        Returns
 758        -------
 759        pandas.DataFrame
 760            the filtered dataframe
 761
 762        """
 763        if episode_names is not None:
 764            data = deepcopy(self.data.loc[episode_names])
 765        else:
 766            data = deepcopy(self.data)
 767        if len(data) == 0:
 768            return pd.DataFrame()
 769        try:
 770            filters = value_filter.split(",")
 771            if filters == [""]:
 772                filters = []
 773            for f in filters:
 774                par_name, condition = f.split("::")
 775                group_name, par_name = par_name.split("/")
 776                sign, value = condition[0], condition[1:]
 777                if value[0] == "=":
 778                    sign += "="
 779                    value = value[1:]
 780                try:
 781                    value = float(value)
 782                except:
 783                    if value == "True":
 784                        value = True
 785                    elif value == "False":
 786                        value = False
 787                    elif value == "None":
 788                        value = None
 789                if value is None:
 790                    if sign == "=":
 791                        data = data[data[group_name][par_name].isna()]
 792                    elif sign == "!=":
 793                        data = data[~data[group_name][par_name].isna()]
 794                elif sign == ">":
 795                    data = data[data[group_name][par_name] > value]
 796                elif sign == ">=":
 797                    data = data[data[group_name][par_name] >= value]
 798                elif sign == "<":
 799                    data = data[data[group_name][par_name] < value]
 800                elif sign == "<=":
 801                    data = data[data[group_name][par_name] <= value]
 802                elif sign == "=":
 803                    data = data[data[group_name][par_name] == value]
 804                elif sign == "!=":
 805                    data = data[data[group_name][par_name] != value]
 806                else:
 807                    raise ValueError(
 808                        "Please use one of the signs: [>, <, >=, <=, =, !=]"
 809                    )
 810        except ValueError:
 811            raise ValueError(
 812                f"The {value_filter} filter is not valid, please use the following format:"
 813                f" 'group1/parameter1::[sign][value],group2/parameter2::[sign][value]', "
 814                f"e.g. 'training/num_epochs::>=200,model/num_f_maps::=128,meta/time::>2022-06-01'"
 815            )
 816        if display_parameters is not None:
 817            if type(display_parameters[0]) is str:
 818                display_parameters = [
 819                    (x.split("/")[0], x.split("/")[1]) for x in display_parameters
 820                ]
 821            display_parameters = [x for x in display_parameters if x in data.columns]
 822            data = data[display_parameters]
 823        return data
 824
 825    def rename_episode(self, episode_name, new_episode_name):
 826        """Rename an episode.
 827
 828        Parameters
 829        ----------
 830        episode_name : str
 831            the name of the episode to rename
 832        new_episode_name : str
 833            the new name of the episode
 834
 835        """
 836        if episode_name in self.data.index and new_episode_name not in self.data.index:
 837            self.data.loc[new_episode_name] = self.data.loc[episode_name]
 838            model_path = self.data.loc[new_episode_name, ("training", "model_path")]
 839            self.data.loc[new_episode_name, ("training", "model_path")] = os.path.join(
 840                os.path.dirname(model_path), new_episode_name
 841            )
 842            log_path = self.data.loc[new_episode_name, ("training", "log_file")]
 843            self.data.loc[new_episode_name, ("training", "log_file")] = os.path.join(
 844                os.path.dirname(log_path), f"{new_episode_name}.txt"
 845            )
 846            self.data = self.data.drop(index=episode_name)
 847            self._save()
 848        else:
 849            raise ValueError("The names are wrong")
 850
 851    def remove_episode(self, episode_name: str) -> None:
 852        """Remove all model, logs and metafile records related to an episode.
 853
 854        Parameters
 855        ----------
 856        episode_name : str
 857            the name of the episode to remove
 858
 859        """
 860        if episode_name in self.data.index:
 861            self.data = self.data.drop(index=episode_name)
 862            self._save()
 863
 864    def unfinished_episodes(self) -> List:
 865        """Get a list of unfinished episodes (currently running or interrupted).
 866
 867        Returns
 868        -------
 869        interrupted_episodes: List
 870            a list of string names of unfinished episodes in the records
 871
 872        """
 873        unfinished = []
 874        for name, params in self.data.iterrows():
 875            if Run(name, project_path=self.project_path, params=params).unfinished():
 876                unfinished.append(name)
 877        return unfinished
 878
 879    def update_episode_results(
 880        self,
 881        episode_name: str,
 882        logs: Tuple,
 883        training_time: str = None,
 884    ) -> None:
 885        """Add results to an episode record.
 886
 887        Parameters
 888        ----------
 889        episode_name : str
 890            the name of the episode to update
 891        logs : dict
 892            a log dictionary from task.train()
 893        training_time : str
 894            the training time
 895
 896        """
 897        metrics_log = logs[1]
 898        results = {}
 899        for key, value in metrics_log["val"].items():
 900            results[("results", key)] = value[-1]
 901        if training_time is not None:
 902            results[("meta", "training_time")] = training_time
 903        for k, v in results.items():
 904            self.data.loc[episode_name, k] = v
 905        self._save()
 906
 907    def get_runs(self, episode_name: str) -> List:
 908        """Get a list of runs with this episode name (episodes like `episode_name#0`).
 909
 910        Parameters
 911        ----------
 912        episode_name : str
 913            the name of the episode
 914
 915        Returns
 916        -------
 917        runs_list : List
 918            a list of string run names
 919
 920        """
 921        if episode_name is None:
 922            return []
 923        index = self.data.index
 924        runs_list = []
 925        for name in index:
 926            if name.startswith(episode_name):
 927                if "::" in name:
 928                    split = name.split("::")
 929                else:
 930                    split = name.split("#")
 931                if split[0] == episode_name:
 932                    if len(split) > 1 and split[-1].isnumeric() or len(split) == 1:
 933                        runs_list.append(name)
 934                elif name == episode_name:
 935                    runs_list.append(name)
 936        return runs_list
 937
 938    def _save(self):
 939        """Save the dataframe."""
 940        self.data.copy().to_pickle(self.path)
 941
 942
 943class Searches(SavedRuns):
 944    """A class that manages operations with search records."""
 945
 946    def save_search(
 947        self,
 948        search_name: str,
 949        parameters: Dict,
 950        n_trials: int,
 951        best_params: Dict,
 952        best_value: float,
 953        metric: str,
 954        search_space: Dict,
 955    ) -> None:
 956        """Save a new search record.
 957
 958        Parameters
 959        ----------
 960        search_name : str
 961            the name of the search to save
 962        parameters : dict
 963            the task parameters to save
 964        n_trials : int
 965            the number of trials in the search
 966        best_params : dict
 967            the best parameters dictionary
 968        best_value : float
 969            the best valie
 970        metric : str
 971            the name of the objective metric
 972        search_space : dict
 973            a dictionary representing the search space; of this general structure:
 974            {'group/param_name': ('float/int/float_log/int_log', start, end),
 975            'group/param_name': ('categorical', [choices])}, e.g.
 976            {'data/overlap': ('int', 5, 100), 'training/lr': ('float_log', 1e-4, 1e-2),
 977            'data/feature_extraction': ('categorical', ['kinematic', 'bones'])}
 978
 979        """
 980        pars = deepcopy(parameters)
 981        pars["results"] = {"best_value": best_value, "best_params": best_params}
 982        pars["meta"] = {
 983            "objective": metric,
 984            "n_trials": n_trials,
 985            "search_space": search_space,
 986        }
 987        self.save_episode(search_name, pars, {})
 988
 989    def get_best_params_raw(self, search_name: str) -> Dict:
 990        """Get the raw dictionary of best parameters found by a search.
 991
 992        Parameters
 993        ----------
 994        search_name : str
 995            the name of the search
 996
 997        Returns
 998        -------
 999        best_params : dict
1000            a dictionary of the best parameters where the keys are in '{group}/{name}' format
1001
1002        """
1003        return self.data.loc[search_name]["results"]["best_params"]
1004
1005    def get_best_params(
1006        self,
1007        search_name: str,
1008        load_parameters: List = None,
1009        round_to_binary: List = None,
1010    ) -> Dict:
1011        """Get the best parameters from a search.
1012
1013        Parameters
1014        ----------
1015        search_name : str
1016            the name of the search
1017        load_parameters : List, optional
1018            a list of string names of the parameters to load (if not provided all parameters are loaded)
1019        round_to_binary : List, optional
1020            a list of string names of the loaded parameters that should be rounded to the nearest power of two
1021
1022        Returns
1023        -------
1024        best_params : dict
1025            a dictionary of the best parameters
1026
1027        """
1028        if round_to_binary is None:
1029            round_to_binary = []
1030        params = self.data.loc[search_name]["results"]["best_params"]
1031        if load_parameters is not None:
1032            params = {k: v for k, v in params.items() if k in load_parameters}
1033        for par_name in round_to_binary:
1034            if par_name not in params:
1035                continue
1036            if not isinstance(params[par_name], float) and not isinstance(
1037                params[par_name], int
1038            ):
1039                raise TypeError(
1040                    f"Cannot round {par_name} parameter of type {type(par_name)} to a power of two"
1041                )
1042            i = 1
1043            while 2**i < params[par_name]:
1044                i += 1
1045            if params[par_name] - (2 ** (i - 1)) < (2**i) - params[par_name]:
1046                params[par_name] = 2 ** (i - 1)
1047            else:
1048                params[par_name] = 2**i
1049        res = defaultdict(lambda: defaultdict(lambda: {}))
1050        for k, v in params.items():
1051            big_key, small_key = k.split("/")[0], "/".join(k.split("/")[1:])
1052            if len(small_key.split("/")) == 1:
1053                res[big_key][small_key] = v
1054            else:
1055                group, key = small_key.split("/")
1056                res[big_key][group][key] = v
1057        model = self.data.loc[search_name]["general"]["model_name"]
1058        return res, model
1059
1060
class Suggestions(SavedRuns):
    """A class that manages operations with suggestion records."""

    def save_suggestion(self, episode_name: str, parameters: Dict, meta_parameters):
        """Save a new suggestion record.

        Parameters
        ----------
        episode_name : str
            the name of the suggestion record
        parameters : dict
            the task parameters to save (deep-copied, the input is not modified)
        meta_parameters
            the value stored under the ``"meta"`` group of the record

        """
        record_params = deepcopy(parameters)
        record_params["meta"] = meta_parameters
        super().save_episode(episode_name, record_params, behaviors_dict=None)
1069
1070
class SavedStores:
    """A class that manages operations with saved dataset records."""

    def __init__(self, path):
        """Initialize the class.

        Parameters
        ----------
        path : str
            the path to the pickled dataframe of saved dataset records

        """
        self.path = path
        self.data = pd.read_pickle(path)
        # Parameters ignored when matching a request against saved records
        self.skip_keys = [
            "feature_save_path",
            "saved_data_path",
            "real_lens",
            "recompute_annotation",
        ]

    def clear(self) -> None:
        """Remove all dataset records (one ``remove_dataset`` call per record)."""
        # Snapshot the names: remove_dataset replaces self.data on every call.
        for dataset_name in list(self.data.index):
            self.remove_dataset(dataset_name)

    def dataset_names(self) -> List:
        """Get a list of dataset names.

        Returns
        -------
        dataset_names : List
            a list of string dataset names

        """
        return list(self.data.index)

    def remove(self, names: List) -> None:
        """Remove some datasets (names that are not in the records are ignored).

        Parameters
        ----------
        names : List
            a list of string names of the datasets to delete

        """
        for dataset_name in names:
            if dataset_name in self.data.index:
                self.remove_dataset(dataset_name)

    def remove_dataset(self, dataset_name: str) -> None:
        """Remove a dataset record and save the dataframe (no-op for unknown names).

        Parameters
        ----------
        dataset_name : str
            the name of the dataset to remove

        """
        if dataset_name in self.data.index:
            self.data = self.data.drop(index=dataset_name)
            self._save()

    def find_name(self, parameters: Dict) -> Union[str, None]:
        """Find a record that satisfies the parameters (if it exists).

        A record matches when it has the same value for every parameter that
        is not ``None`` and not in ``self.skip_keys``, and every other column
        (apart from ``skip_keys``) is null in that record.

        Parameters
        ----------
        parameters : dict
            a dictionary of data parameters

        Returns
        -------
        name : str or None
            the name of the first matching record, or None if there is none
            (also None when a queried parameter has no column in the records)

        """
        query = {
            key: value
            for key, value in parameters.items()
            if value is not None and key not in self.skip_keys
        }
        if any(key not in self.data.columns for key in query):
            return None
        candidates = self.data[
            (self.data[list(query)] == pd.Series(query)).all(axis=1)
        ]

        def _is_null(value):
            # pd.isnull is elementwise for list-like values; only a scalar
            # null counts as "not set".
            flag = pd.isnull(value)
            return isinstance(flag, bool) and flag

        for name, record in candidates.iterrows():
            if all(
                _is_null(record[key])
                for key in candidates.columns
                if key not in self.skip_keys and key not in query
            ):
                return name
        return None

    def save_store(self, episode_name: str, parameters: Dict) -> None:
        """Save a new saved dataset record.

        Columns are added for unseen parameters; the row is only written when
        no existing record matches the parameters (see ``find_name``). The
        dataframe is saved in either case.

        Parameters
        ----------
        episode_name : str
            the name of the dataset
        parameters : dict
            a dictionary of data parameters

        """
        pars = deepcopy(parameters)
        for key in parameters:
            if key not in self.data.columns:
                # New parameter: add a column that is null for existing records.
                self.data[key] = np.nan
        if self.find_name(pars) is None:
            self.data.loc[episode_name] = pars
        self._save()

    def _save(self):
        """Save the dataframe."""
        self.data.to_pickle(self.path)

    def check_name_validity(self, store_name: str) -> bool:
        """Check whether a store name is still free.

        Parameters
        ----------
        store_name : str
            the name to check

        Returns
        -------
        result : bool
            True if the name is not used by an existing record

        """
        return store_name not in self.data.index
1214
1215
1216def _check_str_conversion(params_origin: pd.DataFrame) -> pd.DataFrame:
1217    """Check if the parameters are in string format and convert them to the correct type."""
1218    params = deepcopy(params_origin)
1219
1220    # Early return if no conversion needed
1221    try:
1222        # Check if conversion is needed by testing a known column
1223        test_value = params[("general", "exclusive")]
1224        if isinstance(params, pd.DataFrame):
1225            # For DataFrame, check the first non-null value
1226            test_value = test_value.dropna().iloc[0] if not test_value.dropna().empty else test_value.iloc[0]
1227
1228        if not isinstance(test_value, str):
1229            return params
1230    except (KeyError, IndexError):
1231        # If the test column doesn't exist or is empty, return as-is
1232        return params
1233
1234    def safe_eval(value: str) -> any:
1235        """Safely evaluate a string value with fallback handling."""
1236        if not isinstance(value, str):
1237            return value
1238
1239        # Handle special case for odict_keys
1240        if value.startswith("set(odict_keys("):
1241            value = value.replace("set(odict_keys(", "set(").replace("))", ")")
1242        if value.startswith("ordereddict("):
1243            value = value.replace("ordered", "")
1244        try:
1245            result = eval(value)
1246            # Convert floats ending with .0 to integers
1247            if isinstance(result, float) and result.is_integer():
1248                return int(result)
1249            return result
1250        except (ValueError, SyntaxError, NameError, TypeError):
1251            # If eval fails, keep as string
1252            return value
1253
1254    def convert_stats(value) -> Dict:
1255        """Convert stats values, handling both strings and pandas Series."""
1256        if isinstance(value, pd.Series):
1257            # Handle pandas Series by applying conversion to each element
1258            return value.apply(lambda x: convert_stats(x) if isinstance(x, str) else x)
1259        elif isinstance(value, str):
1260            # Replace tensor(...) with torch.tensor(..., dtype=torch.float) for proper evaluation
1261            s_clean = re.sub(r'tensor\(\s*(\[.*?\])\s*\)', r'torch.tensor(\1, dtype=torch.float)', value, flags=re.DOTALL)
1262            # Evaluate the cleaned string safely
1263            try:
1264                import torch
1265                result = eval(s_clean, {"torch": torch})
1266                return result
1267            except Exception as e:
1268                print(f"Error while converting string: {e}")
1269                return value  # Return original value instead of None
1270        else:
1271            # Return non-string values as-is
1272            return value
1273
1274    if isinstance(params, pd.DataFrame):
1275        for col in params.columns:
1276            if isinstance(col, tuple) and len(col) == 2:  # MultiIndex column
1277                if col[1] == "stats":
1278                    # Special handling for stats columns - convert tensor strings
1279                    params[col] = params[col].apply(convert_stats)
1280                else:
1281                    # Regular string conversion for other columns
1282                    string_mask = params[col].apply(lambda x: isinstance(x, str))
1283                    if string_mask.any():
1284                        params.loc[string_mask, col] = params.loc[string_mask, col].apply(safe_eval)
1285    else:
1286        # Handle Series with MultiIndex
1287        for key, value in params.items():
1288            if isinstance(key, tuple) and len(key) == 2:  # MultiIndex
1289                if key[1] == "stats":
1290                    params[key] = convert_stats(value)
1291                else:
1292                    params[key] = safe_eval(value)
1293
1294    return params
class Run:
 25class Run:
 26    """A class that manages operations with a single episode record."""
 27
 28    def __init__(
 29        self,
 30        episode_name: str,
 31        project_path: str,
 32        meta_path: str = None,
 33        params: Dict = None,
 34    ):
 35        """Initialize the class.
 36
 37        Parameters
 38        ----------
 39        episode_name : str
 40            the name of the episode
 41        project_path : str
 42            the path to the project folder
 43        meta_path : str, optional
 44            the path to the pickled SavedRuns dataframe
 45        params : dict, optional
 46            alternative to meta_path: pre-loaded pandas Series of episode parameters
 47
 48        """
 49        self.name = episode_name
 50        self.project_path = project_path
 51        if meta_path is not None:
 52            try:
 53                self.params = pd.read_pickle(meta_path).loc[episode_name]
 54            except:
 55                raise ValueError(f"The {episode_name} episode does not exist!")
 56        elif params is not None:
 57            self.params = params
 58        else:
 59            raise ValueError("Either meta_path or params has to be not None")
 60        self.params = self._check_str_conversion()
 61
 62    def _check_str_conversion(self):
 63        """Check if the parameters are in string format and convert them to the correct type."""
 64        return _check_str_conversion(self.params)
 65
 66    def training_time(self) -> int:
 67        """Get the training time in seconds.
 68
 69        Returns
 70        -------
 71        training_time : int
 72            the training time in seconds
 73
 74        """
 75        time_str = self.params["meta"].get("training_time")
 76        try:
 77            if time_str is None or np.isnan(time_str):
 78                return np.nan
 79        except TypeError:
 80            pass
 81        h, m, s = time_str.split(":")
 82        seconds = int(h) * 3600 + int(m) * 60 + int(s)
 83        return seconds
 84
 85    def model_file(self, load_epoch: int = None) -> str:
 86        """Get a checkpoint file path.
 87
 88        Parameters
 89        ----------
 90        load_epoch : int, optional
 91            the epoch to load (the closest checkpoint will be chosen; if not given will be set to last)
 92
 93        Returns
 94        -------
 95        checkpoint_path : str
 96            the path to the checkpoint
 97
 98        """
 99        model_path = correct_path(
100            self.params["training"]["model_save_path"], self.project_path
101        )
102        if load_epoch is None:
103            model_file = sorted(os.listdir(model_path))[-1]
104        else:
105            model_files = os.listdir(model_path)
106            if len(model_files) == 0:
107                model_file = None
108            else:
109                epochs = [int(file[5:].split(".")[0]) for file in model_files]
110                diffs = [np.abs(epoch - load_epoch) for epoch in epochs]
111                argmin = np.argmin(diffs)
112                model_file = model_files[argmin]
113        model_file = os.path.join(model_path, model_file)
114        return model_file
115
116    def dataset_name(self) -> str:
117        """Get the dataset name.
118
119        Returns
120        -------
121        dataset_name : str
122            the name of the dataset record
123
124        """
125        data_path = correct_path(
126            self.params["data"]["feature_save_path"], self.project_path
127        )
128        dataset_name = os.path.basename(data_path)
129        return dataset_name
130
131    def split_file(self) -> str:
132        """Get the split file.
133
134        Returns
135        -------
136        split_path : str
137            the path to the split file
138
139        """
140        return correct_path(self.params["training"]["split_path"], self.project_path)
141
142    def log_file(self) -> str:
143        """Get the log file.
144
145        Returns
146        -------
147        log_path : str
148            the path to the log file
149
150        """
151        return correct_path(self.params["training"]["log_file"], self.project_path)
152
153    def split_info(self) -> Dict:
154        """Get the train/test/val split information.
155
156        Returns
157        -------
158        split_info : dict
159            a dictionary with [val_frac, test_frac, partition_method] keys and corresponding values
160
161        """
162        val_frac = self.params["training"]["val_frac"]
163        test_frac = self.params["training"]["test_frac"]
164        partition_method = self.params["training"]["partition_method"]
165        return {
166            "val_frac": val_frac,
167            "test_frac": test_frac,
168            "partition_method": partition_method,
169        }
170
171    def same_split_info(self, split_info: Dict) -> bool:
172        """Check whether this episode has the same split information.
173
174        Parameters
175        ----------
176        split_info : dict
177            a dictionary with [val_frac, test_frac, partition_method] keys and corresponding values from another episode
178
179        Returns
180        -------
181        result : bool
182            if True, this episode has the same split information
183
184        """
185        self_split_info = self.split_info()
186        for k in ["val_frac", "test_frac", "partition_method"]:
187            if self_split_info[k] != split_info[k]:
188                return False
189        return True
190
191    def get_metrics(self) -> List:
192        """Get a list of tracked metrics.
193
194        Returns
195        -------
196        metrics : list
197            a list of tracked metric names
198
199        """
200        return self.params["general"]["metric_functions"]
201
202    def get_metric_log(self, mode: str, metric_name: str) -> np.ndarray:
203        """Get the metric log.
204
205        Parameters
206        ----------
207        mode : {'train', 'val'}
208            the mode to get the log from
209        metric_name : str
210            the metric to get the log for (has to be one of the metric computed for this episode during training)
211
212        Returns
213        -------
214        log : np.ndarray
215            the log of metric values (empty if the metric was not computed during training)
216
217        """
218        metric_array = []
219        with open(self.log_file()) as f:
220            for line in f.readlines():
221                if mode == "train" and line.startswith("[epoch"):
222                    line = line.split("]: ")[1]
223                elif mode == "val" and line.startswith("validation"):
224                    line = line.split("validation: ")[1]
225                else:
226                    continue
227                metrics = line.split(", ")
228
229                metric_ind = np.where(
230                    np.array([m.split()[0] for m in metrics]) == metric_name
231                )[0]
232                if len(metric_ind):
233                    name, value = metrics[metric_ind[0]].split()
234                    metric_array.append(float(value))
235                else:
236                    metric_inds = [
237                        m for m in metrics if m.split()[0].split("_")[0] == metric_name
238                    ]
239                    if len(metric_inds):
240                        beh_metrics_avg = np.mean(
241                            [float(m.split()[1]) for m in metric_inds]
242                        )
243                        metric_array.append(beh_metrics_avg)
244
245        return np.array(metric_array)
246
247    def get_epoch_list(self, mode) -> List:
248        """Get a list of epoch indices.
249
250        Parameters
251        ----------
252        mode : {'train', 'val'}
253            the mode to get the epoch list for
254
255        Returns
256        -------
257        epoch_list : list
258            a list of int epoch indices
259
260        """
261        epoch_list = []
262        with open(self.log_file()) as f:
263            for line in f.readlines():
264                if line.startswith("[epoch"):
265                    epoch = int(line[7:].split("]:")[0])
266                    if mode == "train":
267                        epoch_list.append(epoch)
268                elif mode == "val":
269                    epoch_list.append(epoch)
270        return epoch_list
271
272    def get_metrics(self) -> List:
273        """Get a list of metric names in the episode log.
274
275        Returns
276        -------
277        metrics : List
278            a list of string metric names
279
280        """
281        metrics = []
282        with open(self.log_file()) as f:
283            for line in f.readlines():
284                if line.startswith("[epoch"):
285                    line = line.split("]: ")[1]
286                elif line.startswith("validation"):
287                    line = line.split("validation: ")[1]
288                else:
289                    continue
290                metric_logs = line.split(", ")
291                for metric in metric_logs:
292                    name, _ = metric.split()
293                    metrics.append(name)
294                break
295        return metrics
296
297    def unfinished(self) -> bool:
298        """Check whether this episode was interrupted.
299
300        Returns
301        -------
302        result : bool
303            True if the number of epochs in the log file is smaller than in the parameters
304
305        """
306        num_epoch_theor = self.params["training"]["num_epochs"]
307        log_file = self.log_file()
308        if not isinstance(log_file, str):
309            return False
310        if not os.path.exists(log_file):
311            return True
312        with open(self.log_file()) as f:
313            num_epoch = 0
314            val = False
315            for line in f.readlines():
316                num_epoch += 1
317                if num_epoch == 2 and line.startswith("validation"):
318                    val = True
319            if val:
320                num_epoch //= 2
321        return num_epoch < num_epoch_theor
322
323    def get_class_ind(self, class_name: str) -> int:
324        """Get the integer label from a class name.
325
326        Parameters
327        ----------
328        class_name : str
329            the name of the class
330
331        Returns
332        -------
333        class_ind : int
334            the integer label
335
336        """
337        behaviors_dict = self.params["meta"]["behaviors_dict"]
338        for k, v in behaviors_dict.items():
339            if v == class_name:
340                return k
341        raise ValueError(
342            f"The {class_name} class is not in classes predicted by {self.name} ({behaviors_dict})"
343        )
344
345    def get_behaviors_dict(self) -> Dict:
346        """Get behaviors dictionary in the episode.
347
348        Returns
349        -------
350        behaviors_dict : dict
351            a dictionary with class indices as keys and labels as values
352
353        """
354        behavior_dict = self.params["meta"]["behaviors_dict"]
355        if isinstance(behavior_dict, str):
356            behavior_dict = ast.literal_eval(behavior_dict)
357
358        return behavior_dict
359
360    def get_num_classes(self) -> int:
361        """Get number of classes in episode.
362
363        Returns
364        -------
365        num_classes : int
366            the number of classes
367
368        """
369        return len(self.params["meta"]["behaviors_dict"])

A class that manages operations with a single episode record.

Run( episode_name: str, project_path: str, meta_path: str = None, params: Dict = None)
28    def __init__(
29        self,
30        episode_name: str,
31        project_path: str,
32        meta_path: str = None,
33        params: Dict = None,
34    ):
35        """Initialize the class.
36
37        Parameters
38        ----------
39        episode_name : str
40            the name of the episode
41        project_path : str
42            the path to the project folder
43        meta_path : str, optional
44            the path to the pickled SavedRuns dataframe
45        params : dict, optional
46            alternative to meta_path: pre-loaded pandas Series of episode parameters
47
48        """
49        self.name = episode_name
50        self.project_path = project_path
51        if meta_path is not None:
52            try:
53                self.params = pd.read_pickle(meta_path).loc[episode_name]
54            except:
55                raise ValueError(f"The {episode_name} episode does not exist!")
56        elif params is not None:
57            self.params = params
58        else:
59            raise ValueError("Either meta_path or params has to be not None")
60        self.params = self._check_str_conversion()

Initialize the class.

Parameters

- episode_name : str — the name of the episode
- project_path : str — the path to the project folder
- meta_path : str, optional — the path to the pickled SavedRuns dataframe
- params : dict, optional — alternative to meta_path: a pre-loaded pandas Series of episode parameters

name
project_path
params
def training_time(self) -> int:
66    def training_time(self) -> int:
67        """Get the training time in seconds.
68
69        Returns
70        -------
71        training_time : int
72            the training time in seconds
73
74        """
75        time_str = self.params["meta"].get("training_time")
76        try:
77            if time_str is None or np.isnan(time_str):
78                return np.nan
79        except TypeError:
80            pass
81        h, m, s = time_str.split(":")
82        seconds = int(h) * 3600 + int(m) * 60 + int(s)
83        return seconds

Get the training time in seconds.

Returns

- training_time : int — the training time in seconds

def model_file(self, load_epoch: int = None) -> str:
 85    def model_file(self, load_epoch: int = None) -> str:
 86        """Get a checkpoint file path.
 87
 88        Parameters
 89        ----------
 90        load_epoch : int, optional
 91            the epoch to load (the closest checkpoint will be chosen; if not given will be set to last)
 92
 93        Returns
 94        -------
 95        checkpoint_path : str
 96            the path to the checkpoint
 97
 98        """
 99        model_path = correct_path(
100            self.params["training"]["model_save_path"], self.project_path
101        )
102        if load_epoch is None:
103            model_file = sorted(os.listdir(model_path))[-1]
104        else:
105            model_files = os.listdir(model_path)
106            if len(model_files) == 0:
107                model_file = None
108            else:
109                epochs = [int(file[5:].split(".")[0]) for file in model_files]
110                diffs = [np.abs(epoch - load_epoch) for epoch in epochs]
111                argmin = np.argmin(diffs)
112                model_file = model_files[argmin]
113        model_file = os.path.join(model_path, model_file)
114        return model_file

Get a checkpoint file path.

Parameters

- load_epoch : int, optional — the epoch to load (the closest checkpoint is chosen; if not given, the last one is used)

Returns

- checkpoint_path : str — the path to the checkpoint

def dataset_name(self) -> str:
116    def dataset_name(self) -> str:
117        """Get the dataset name.
118
119        Returns
120        -------
121        dataset_name : str
122            the name of the dataset record
123
124        """
125        data_path = correct_path(
126            self.params["data"]["feature_save_path"], self.project_path
127        )
128        dataset_name = os.path.basename(data_path)
129        return dataset_name

Get the dataset name.

Returns

dataset_name : str the name of the dataset record

def split_file(self) -> str:
131    def split_file(self) -> str:
132        """Get the split file.
133
134        Returns
135        -------
136        split_path : str
137            the path to the split file
138
139        """
140        return correct_path(self.params["training"]["split_path"], self.project_path)

Get the split file.

Returns

split_path : str the path to the split file

def log_file(self) -> str:
142    def log_file(self) -> str:
143        """Get the log file.
144
145        Returns
146        -------
147        log_path : str
148            the path to the log file
149
150        """
151        return correct_path(self.params["training"]["log_file"], self.project_path)

Get the log file.

Returns

log_path : str the path to the log file

def split_info(self) -> Dict:
153    def split_info(self) -> Dict:
154        """Get the train/test/val split information.
155
156        Returns
157        -------
158        split_info : dict
159            a dictionary with [val_frac, test_frac, partition_method] keys and corresponding values
160
161        """
162        val_frac = self.params["training"]["val_frac"]
163        test_frac = self.params["training"]["test_frac"]
164        partition_method = self.params["training"]["partition_method"]
165        return {
166            "val_frac": val_frac,
167            "test_frac": test_frac,
168            "partition_method": partition_method,
169        }

Get the train/test/val split information.

Returns

split_info : dict a dictionary with [val_frac, test_frac, partition_method] keys and corresponding values

def same_split_info(self, split_info: Dict) -> bool:
171    def same_split_info(self, split_info: Dict) -> bool:
172        """Check whether this episode has the same split information.
173
174        Parameters
175        ----------
176        split_info : dict
177            a dictionary with [val_frac, test_frac, partition_method] keys and corresponding values from another episode
178
179        Returns
180        -------
181        result : bool
182            if True, this episode has the same split information
183
184        """
185        self_split_info = self.split_info()
186        for k in ["val_frac", "test_frac", "partition_method"]:
187            if self_split_info[k] != split_info[k]:
188                return False
189        return True

Check whether this episode has the same split information.

Parameters

split_info : dict a dictionary with [val_frac, test_frac, partition_method] keys and corresponding values from another episode

Returns

result : bool if True, this episode has the same split information

def get_metrics(self) -> List:
272    def get_metrics(self) -> List:
273        """Get a list of metric names in the episode log.
274
275        Returns
276        -------
277        metrics : List
278            a list of string metric names
279
280        """
281        metrics = []
282        with open(self.log_file()) as f:
283            for line in f.readlines():
284                if line.startswith("[epoch"):
285                    line = line.split("]: ")[1]
286                elif line.startswith("validation"):
287                    line = line.split("validation: ")[1]
288                else:
289                    continue
290                metric_logs = line.split(", ")
291                for metric in metric_logs:
292                    name, _ = metric.split()
293                    metrics.append(name)
294                break
295        return metrics

Get a list of metric names in the episode log.

Returns

metrics : List a list of string metric names

def get_metric_log(self, mode: str, metric_name: str) -> numpy.ndarray:
202    def get_metric_log(self, mode: str, metric_name: str) -> np.ndarray:
203        """Get the metric log.
204
205        Parameters
206        ----------
207        mode : {'train', 'val'}
208            the mode to get the log from
209        metric_name : str
210            the metric to get the log for (has to be one of the metric computed for this episode during training)
211
212        Returns
213        -------
214        log : np.ndarray
215            the log of metric values (empty if the metric was not computed during training)
216
217        """
218        metric_array = []
219        with open(self.log_file()) as f:
220            for line in f.readlines():
221                if mode == "train" and line.startswith("[epoch"):
222                    line = line.split("]: ")[1]
223                elif mode == "val" and line.startswith("validation"):
224                    line = line.split("validation: ")[1]
225                else:
226                    continue
227                metrics = line.split(", ")
228
229                metric_ind = np.where(
230                    np.array([m.split()[0] for m in metrics]) == metric_name
231                )[0]
232                if len(metric_ind):
233                    name, value = metrics[metric_ind[0]].split()
234                    metric_array.append(float(value))
235                else:
236                    metric_inds = [
237                        m for m in metrics if m.split()[0].split("_")[0] == metric_name
238                    ]
239                    if len(metric_inds):
240                        beh_metrics_avg = np.mean(
241                            [float(m.split()[1]) for m in metric_inds]
242                        )
243                        metric_array.append(beh_metrics_avg)
244
245        return np.array(metric_array)

Get the metric log.

Parameters

mode : {'train', 'val'} the mode to get the log from; metric_name : str the metric to get the log for (has to be one of the metrics computed for this episode during training)

Returns

log : np.ndarray the log of metric values (empty if the metric was not computed during training)

def get_epoch_list(self, mode) -> List:
247    def get_epoch_list(self, mode) -> List:
248        """Get a list of epoch indices.
249
250        Parameters
251        ----------
252        mode : {'train', 'val'}
253            the mode to get the epoch list for
254
255        Returns
256        -------
257        epoch_list : list
258            a list of int epoch indices
259
260        """
261        epoch_list = []
262        with open(self.log_file()) as f:
263            for line in f.readlines():
264                if line.startswith("[epoch"):
265                    epoch = int(line[7:].split("]:")[0])
266                    if mode == "train":
267                        epoch_list.append(epoch)
268                elif mode == "val":
269                    epoch_list.append(epoch)
270        return epoch_list

Get a list of epoch indices.

Parameters

mode : {'train', 'val'} the mode to get the epoch list for

Returns

epoch_list : list a list of int epoch indices

def unfinished(self) -> bool:
297    def unfinished(self) -> bool:
298        """Check whether this episode was interrupted.
299
300        Returns
301        -------
302        result : bool
303            True if the number of epochs in the log file is smaller than in the parameters
304
305        """
306        num_epoch_theor = self.params["training"]["num_epochs"]
307        log_file = self.log_file()
308        if not isinstance(log_file, str):
309            return False
310        if not os.path.exists(log_file):
311            return True
312        with open(self.log_file()) as f:
313            num_epoch = 0
314            val = False
315            for line in f.readlines():
316                num_epoch += 1
317                if num_epoch == 2 and line.startswith("validation"):
318                    val = True
319            if val:
320                num_epoch //= 2
321        return num_epoch < num_epoch_theor

Check whether this episode was interrupted.

Returns

result : bool True if the number of epochs in the log file is smaller than in the parameters

def get_class_ind(self, class_name: str) -> int:
323    def get_class_ind(self, class_name: str) -> int:
324        """Get the integer label from a class name.
325
326        Parameters
327        ----------
328        class_name : str
329            the name of the class
330
331        Returns
332        -------
333        class_ind : int
334            the integer label
335
336        """
337        behaviors_dict = self.params["meta"]["behaviors_dict"]
338        for k, v in behaviors_dict.items():
339            if v == class_name:
340                return k
341        raise ValueError(
342            f"The {class_name} class is not in classes predicted by {self.name} ({behaviors_dict})"
343        )

Get the integer label from a class name.

Parameters

class_name : str the name of the class

Returns

class_ind : int the integer label

def get_behaviors_dict(self) -> Dict:
345    def get_behaviors_dict(self) -> Dict:
346        """Get behaviors dictionary in the episode.
347
348        Returns
349        -------
350        behaviors_dict : dict
351            a dictionary with class indices as keys and labels as values
352
353        """
354        behavior_dict = self.params["meta"]["behaviors_dict"]
355        if isinstance(behavior_dict, str):
356            behavior_dict = ast.literal_eval(behavior_dict)
357
358        return behavior_dict

Get behaviors dictionary in the episode.

Returns

behaviors_dict : dict a dictionary with class indices as keys and labels as values

def get_num_classes(self) -> int:
360    def get_num_classes(self) -> int:
361        """Get number of classes in episode.
362
363        Returns
364        -------
365        num_classes : int
366            the number of classes
367
368        """
369        return len(self.params["meta"]["behaviors_dict"])

Get number of classes in episode.

Returns

num_classes : int the number of classes

class DecisionThresholds:
class DecisionThresholds:
    """A class that saves and looks up tuned decision thresholds.

    Records are rows of a pickled pandas dataframe. Each row holds the metric
    parameters as ``(metric_name, parameter_name)`` columns, a ``"thresholds"``
    column and an ``"episodes"`` column with the set of
    ``(episode_name, epoch)`` pairs.
    """

    # metric parameters that do not take part in identifying a record
    _IGNORED_PARAMETERS = ("average", "threshold_value", "ignored_classes")

    def __init__(self, path: str) -> None:
        """Initialize the class.

        Parameters
        ----------
        path : str
            the path to the pickled SavedRuns dataframe

        """
        self.path = path
        self.data = pd.read_pickle(path)

    @classmethod
    def _metric_columns(cls, metric_name: str, metric_parameters: Dict) -> Dict:
        """Map relevant metric parameters to ``(metric_name, parameter)`` columns.

        The caller's ``metric_parameters`` dictionary is not modified.
        """
        return {
            (metric_name, k): v
            for k, v in metric_parameters.items()
            if k not in cls._IGNORED_PARAMETERS
        }

    def save_thresholds(
        self,
        episode_names: List,
        epochs: List,
        metric_name: str,
        metric_parameters: Dict,
        thresholds: List,
    ) -> None:
        """Add a new record and save the records to disk.

        Parameters
        ----------
        episode_names : list
            the names of the episodes
        epochs : list
            the epoch index list
        metric_name : str
            the name of the metric the thresholds were tuned on
        metric_parameters : dict
            the metric parameter dictionary (left unmodified)
        thresholds : list
            a list of float decision thresholds

        """
        parameters = self._metric_columns(metric_name, metric_parameters)
        parameters["thresholds"] = thresholds
        parameters["episodes"] = set(zip(episode_names, epochs))
        pars = {k: [v] for k, v in parameters.items()}
        self.data = pd.concat([self.data, pd.DataFrame.from_dict(pars)], axis=0)
        self._save()

    def find_thresholds(
        self,
        episode_names: List,
        epochs: List,
        metric_name: str,
        metric_parameters: Dict,
    ) -> Union[List, None]:
        """Find a record.

        Parameters whose value is None do not constrain the search.

        Parameters
        ----------
        episode_names : list
            the names of the episodes
        epochs : list
            the epoch index list
        metric_name : str
            the name of the metric the thresholds were tuned on
        metric_parameters : dict
            the metric parameter dictionary (left unmodified)

        Returns
        -------
        thresholds : list or None
            a list of float decision thresholds, or None if no matching record exists

        """
        parameters = self._metric_columns(metric_name, metric_parameters)
        parameters["episodes"] = set(zip(episode_names, epochs))
        query = {}
        for key, value in parameters.items():
            if value is None:
                continue
            if key not in self.data.columns:
                # a parameter that was never recorded cannot match any row
                return None
            query[key] = value
        data = self.data[(self.data[list(query)] == pd.Series(query)).all(axis=1)]
        if len(data) > 0:
            return data.iloc[0]["thresholds"]
        return None

    def _save(self) -> None:
        """Save the records."""
        self.data.copy().to_pickle(self.path)

A class that saves and looks up tuned decision thresholds.

DecisionThresholds(path: str)
375    def __init__(self, path: str) -> None:
376        """Initialize the class.
377
378        Parameters
379        ----------
380        path : str
381            the path to the pickled SavedRuns dataframe
382
383        """
384        self.path = path
385        self.data = pd.read_pickle(path)

Initialize the class.

Parameters

path : str the path to the pickled SavedRuns dataframe

path
data
def save_thresholds( self, episode_names: List, epochs: List, metric_name: str, metric_parameters: Dict, thresholds: List) -> None:
387    def save_thresholds(
388        self,
389        episode_names: List,
390        epochs: List,
391        metric_name: str,
392        metric_parameters: Dict,
393        thresholds: List,
394    ) -> None:
395        """Add a new record.
396
397        Parameters
398        ----------
399        episode_names : list
400            the names of the episodes
401        epochs : int
402            the epoch index list
403        metric_name : str
404            the name of the metric the thresholds were tuned on
405        metric_parameters : dict
406            the metric parameter dictionary
407        thresholds : list
408            a list of float decision thresholds
409
410        """
411        episodes = set(zip(episode_names, epochs))
412        for key in ["average", "threshold_value", "ignored_classes"]:
413            if key in metric_parameters:
414                metric_parameters.pop(key)
415        parameters = {(metric_name, k): v for k, v in metric_parameters.items()}
416        parameters["thresholds"] = thresholds
417        parameters["episodes"] = episodes
418        pars = {k: [v] for k, v in parameters.items()}
419        self.data = pd.concat([self.data, pd.DataFrame.from_dict(pars)], axis=0)
420        self._save()

Add a new record.

Parameters

episode_names : list the names of the episodes; epochs : list the epoch index list; metric_name : str the name of the metric the thresholds were tuned on; metric_parameters : dict the metric parameter dictionary; thresholds : list a list of float decision thresholds

def find_thresholds( self, episode_names: List, epochs: List, metric_name: str, metric_parameters: Dict) -> Optional[List]:
422    def find_thresholds(
423        self,
424        episode_names: List,
425        epochs: List,
426        metric_name: str,
427        metric_parameters: Dict,
428    ) -> Union[List, None]:
429        """Find a record.
430
431        Parameters
432        ----------
433        episode_names : list
434            the names of the episodes
435        epochs : list
436            the epoch index list
437        metric_name : str
438            the name of the metric the thresholds were tuned on
439        metric_parameters : dict
440            the metric parameter dictionary
441
442        Returns
443        -------
444        thresholds : list
445            a list of float decision thresholds
446
447        """
448        episodes = set(zip(episode_names, epochs))
449        for key in ["average", "threshold_value", "ignored_classes"]:
450            if key in metric_parameters:
451                metric_parameters.pop(key)
452        parameters = {(metric_name, k): v for k, v in metric_parameters.items()}
453        parameters["episodes"] = episodes
454        filter = deepcopy(parameters)
455        for key, value in parameters.items():
456            if value is None:
457                filter.pop(key)
458            elif key not in self.data.columns:
459                return None
460        data = self.data[(self.data[list(filter)] == pd.Series(filter)).all(axis=1)]
461        if len(data) > 0:
462            thresholds = data.iloc[0]["thresholds"]
463            return thresholds
464        else:
465            return None

Find a record.

Parameters

episode_names : list the names of the episodes epochs : list the epoch index list metric_name : str the name of the metric the thresholds were tuned on metric_parameters : dict the metric parameter dictionary

Returns

thresholds : list a list of float decision thresholds

class SavedRuns:
472class SavedRuns:
473    """A class that manages operations with all episode (or prediction) records."""
474
475    def __init__(self, path: str, project_path: str) -> None:
476        """Initialize the class.
477
478        Parameters
479        ----------
480        path : str
481            the path to the pickled SavedRuns dataframe
482        project_path : str
483            the path to the project folder
484
485        """
486        self.path = path
487        self.project_path = project_path
488        self.data = pd.read_pickle(path)
489        self.data = _check_str_conversion(self.data)
490
491    def update(
492        self,
493        data: pd.DataFrame,
494        data_path: str,
495        annotation_path: str,
496        name_map: Dict = None,
497        force: bool = False,
498    ) -> None:
499        """Update with new data.
500
501        Parameters
502        ----------
503        data : pd.DataFrame
504            the new dataframe
505        data_path : str
506            the new data path
507        annotation_path : str
508            the new annotation path
509        name_map : dict, optional
510            the name change dictionary; keys are old episode names and values are new episode names
511        force : bool, default False
512            replace existing episodes if `True`
513
514        """
515        if name_map is None:
516            name_map = {}
517        data = data.rename(index=name_map)
518        for episode in data.index:
519            new_model = os.path.join(self.project_path, "results", "model", episode)
520            data.loc[episode, ("training", "model_save_path")] = new_model
521            new_log = os.path.join(
522                self.project_path, "results", "logs", f"{episode}.txt"
523            )
524            data.loc[episode, ("training", "log_file")] = new_log
525            old_split = data.loc[episode, ("training", "split_path")]
526            if old_split is None:
527                new_split = None
528            else:
529                new_split = os.path.join(
530                    self.project_path, "results", "splits", os.path.basename(old_split)
531                )
532            data.loc[episode, ("training", "split_path")] = new_split
533            data.loc[episode, ("data", "data_path")] = data_path
534            data.loc[episode, ("data", "annotation_path")] = annotation_path
535            if episode in self.data.index:
536                if force:
537                    self.data = self.data.drop(index=[episode])
538                else:
539                    raise RuntimeError(f"The {episode} episode name is already taken!")
540        self.data = pd.concat([self.data, data])
541        self._save()
542
543    def get_subset(self, episode_names: List) -> pd.DataFrame:
544        """Get a subset of the raw metadata.
545
546        Parameters
547        ----------
548        episode_names : list
549            a list of the episodes to include
550
551        Returns
552        -------
553        subset : pd.DataFrame
554            the subset of the raw metadata
555
556        """
557        for episode in episode_names:
558            if episode not in self.data.index:
559                raise ValueError(
560                    f"The {episode} episode is not in the records; please run `Project.list_episodes()` to explore the records"
561                )
562        return self.data.loc[episode_names]
563
564    def get_saved_data_path(self, episode_name: str) -> str:
565        """Get the `saved_data_path` parameter for the episode.
566
567        Parameters
568        ----------
569        episode_name : str
570            the name of the episode
571
572        Returns
573        -------
574        saved_data_path : str
575            the saved data path
576
577        """
578        return self.data.loc[episode_name]["data"]["saved_data_path"]
579
580    def check_name_validity(self, episode_name: str) -> bool:
581        """Check if an episode name already exists.
582
583        Parameters
584        ----------
585        episode_name : str
586            the name to check
587
588        Returns
589        -------
590        result : bool
591            True if the name can be used
592
593        """
594        if episode_name in self.data.index:
595            return False
596        else:
597            return True
598
599    def update_episode_metrics(self, episode_name: str, metrics: Dict) -> None:
600        """Update meta data with evaluation results.
601
602        Parameters
603        ----------
604        episode_name : str
605            the name of the episode to update
606        metrics : dict
607            a dictionary of the metrics
608
609        """
610        for key, value in metrics.items():
611            self.data.loc[episode_name, ("results", key)] = value
612        self._save()
613
614    def save_episode(
615        self,
616        episode_name: str,
617        parameters: Dict,
618        behaviors_dict: Dict,
619        suppress_validation: bool = False,
620        training_time: str = None,
621    ) -> None:
622        """Save a new run record.
623
624        Parameters
625        ----------
626        episode_name : str
627            the name of the episode
628        parameters : dict
629            the parameters to save
630        behaviors_dict : dict
631            the dictionary of behaviors (keys are indices, values are names)
632        suppress_validation : bool, optional False
633            if True, existing episode with the same name will be overwritten
634        training_time : str, optional
635            the training time in '%H:%M:%S' format
636
637        """
638        if not suppress_validation and episode_name in self.data.index:
639            raise ValueError(f"Episode {episode_name} already exists!")
640        pars = deepcopy(parameters)
641        if "meta" not in pars:
642            pars["meta"] = {
643                "time": strftime("%Y-%m-%d %H:%M:%S", localtime()),
644                "behaviors_dict": behaviors_dict,
645            }
646        else:
647            pars["meta"]["time"] = strftime("%Y-%m-%d %H:%M:%S", localtime())
648            pars["meta"]["behaviors_dict"] = behaviors_dict
649        if training_time is not None:
650            pars["meta"]["training_time"] = training_time
651        if len(parameters.keys()) > 1:
652            pars["losses"] = pars["losses"].get(pars["general"]["loss_function"], {})
653            for metric_name in pars["general"]["metric_functions"]:
654                pars[metric_name] = pars["metrics"].get(metric_name, {})
655            if pars["general"].get("ssl", None) is not None:
656                for ssl_name in pars["general"]["ssl"]:
657                    pars[ssl_name] = pars["ssl"].get(ssl_name, {})
658            for group_name in ["metrics", "ssl"]:
659                if group_name in pars:
660                    pars.pop(group_name)
661        data = {
662            (big_key, small_key): value
663            for big_key, big_value in pars.items()
664            for small_key, value in big_value.items()
665        }
666        list_keys = []
667        with warnings.catch_warnings():
668            warnings.filterwarnings("ignore", message="DataFrame is highly fragmented")
669            for k, v in data.items():
670                if k not in self.data.columns:
671                    self.data[k] = np.nan
672                if isinstance(v, list) and not isinstance(v, str):
673                    list_keys.append(k)
674            for k in list_keys:
675                self.data[k] = self.data[k].astype(object)
676            self.data.loc[episode_name] = data
677        self._save()
678
679    def load_parameters(self, episode_name: str) -> Dict:
680        """Load the task parameters from a record.
681
682        Parameters
683        ----------
684        episode_name : str
685            the name of the episode to load
686
687        Returns
688        -------
689        parameters : dict
690            the loaded task parameters
691
692        """
693        parameters = defaultdict(lambda: defaultdict(lambda: {}))
694        episode = self.data.loc[episode_name].dropna().to_dict()
695        keys = ["data", "augmentations", "general", "training", "model", "features"]
696        for key in episode:
697            big_key, small_key = key
698            if big_key in keys:
699                parameters[big_key][small_key] = episode[key]
700        # parameters = {k: dict(v) for k, v in parameters.items()}
701        ssl_keys = parameters["general"].get("ssl", None)
702        metric_keys = parameters["general"].get("metric_functions", None)
703        loss_key = parameters["general"]["loss_function"]
704        if ssl_keys is None:
705            ssl_keys = []
706        if metric_keys is None:
707            metric_keys = []
708        for key in episode:
709            big_key, small_key = key
710            if big_key in ssl_keys:
711                parameters["ssl"][big_key][small_key] = episode[key]
712            elif big_key in metric_keys:
713                parameters["metrics"][big_key][small_key] = episode[key]
714            elif big_key == "losses":
715                parameters["losses"][loss_key][small_key] = episode[key]
716        parameters = {k: dict(v) for k, v in parameters.items()}
717        parameters["general"]["num_classes"] = Run(
718            episode_name, self.project_path, params=self.data.loc[episode_name]
719        ).get_num_classes()
720        return parameters
721
722    def get_active_datasets(self) -> List:
723        """Get a list of names of datasets that are used by unfinished episodes.
724
725        Returns
726        -------
727        active_datasets : list
728            a list of dataset names used by unfinished episodes
729
730        """
731        active_datasets = []
732        for episode_name in self.unfinished_episodes():
733            run = Run(
734                episode_name, self.project_path, params=self.data.loc[episode_name]
735            )
736            active_datasets.append(run.dataset_name())
737        return active_datasets
738
739    def list_episodes(
740        self,
741        episode_names: List = None,
742        value_filter: str = "",
743        display_parameters: List = None,
744    ) -> pd.DataFrame:
745        """Get a filtered pandas dataframe with episode metadata.
746
747        Parameters
748        ----------
749        episode_names : List
750            a list of strings of episode names
751        value_filter : str
752            a string of filters to apply of this general structure:
753            'group_name1/par_name1::(<>=)value1,group_name2/par_name2::(<>=)value2', e.g.
754            'data/overlap::=50,results/recall::>0.5,data/feature_extraction::=kinematic'
755        display_parameters : List
756            list of parameters to display (e.g. ['data/overlap', 'results/recall'])
757
758        Returns
759        -------
760        pandas.DataFrame
761            the filtered dataframe
762
763        """
764        if episode_names is not None:
765            data = deepcopy(self.data.loc[episode_names])
766        else:
767            data = deepcopy(self.data)
768        if len(data) == 0:
769            return pd.DataFrame()
770        try:
771            filters = value_filter.split(",")
772            if filters == [""]:
773                filters = []
774            for f in filters:
775                par_name, condition = f.split("::")
776                group_name, par_name = par_name.split("/")
777                sign, value = condition[0], condition[1:]
778                if value[0] == "=":
779                    sign += "="
780                    value = value[1:]
781                try:
782                    value = float(value)
783                except:
784                    if value == "True":
785                        value = True
786                    elif value == "False":
787                        value = False
788                    elif value == "None":
789                        value = None
790                if value is None:
791                    if sign == "=":
792                        data = data[data[group_name][par_name].isna()]
793                    elif sign == "!=":
794                        data = data[~data[group_name][par_name].isna()]
795                elif sign == ">":
796                    data = data[data[group_name][par_name] > value]
797                elif sign == ">=":
798                    data = data[data[group_name][par_name] >= value]
799                elif sign == "<":
800                    data = data[data[group_name][par_name] < value]
801                elif sign == "<=":
802                    data = data[data[group_name][par_name] <= value]
803                elif sign == "=":
804                    data = data[data[group_name][par_name] == value]
805                elif sign == "!=":
806                    data = data[data[group_name][par_name] != value]
807                else:
808                    raise ValueError(
809                        "Please use one of the signs: [>, <, >=, <=, =, !=]"
810                    )
811        except ValueError:
812            raise ValueError(
813                f"The {value_filter} filter is not valid, please use the following format:"
814                f" 'group1/parameter1::[sign][value],group2/parameter2::[sign][value]', "
815                f"e.g. 'training/num_epochs::>=200,model/num_f_maps::=128,meta/time::>2022-06-01'"
816            )
817        if display_parameters is not None:
818            if type(display_parameters[0]) is str:
819                display_parameters = [
820                    (x.split("/")[0], x.split("/")[1]) for x in display_parameters
821                ]
822            display_parameters = [x for x in display_parameters if x in data.columns]
823            data = data[display_parameters]
824        return data
825
def rename_episode(self, episode_name, new_episode_name):
    """Rename an episode record.

    The record's ``training/model_save_path`` and ``training/log_file``
    entries are redirected so that their last path component matches the new
    name (the log file becomes ``{new_episode_name}.txt``). Only the metadata
    record changes; nothing on disk is moved or renamed.

    Parameters
    ----------
    episode_name : str
        the name of the episode to rename
    new_episode_name : str
        the new name of the episode

    Raises
    ------
    ValueError
        if `episode_name` is not in the records or `new_episode_name` is
        already taken

    """
    if episode_name not in self.data.index or new_episode_name in self.data.index:
        raise ValueError("The names are wrong")
    # Relabel in place so the row keeps its position in the table.
    self.data = self.data.rename(index={episode_name: new_episode_name})
    # ``update`` stores the model folder under ``model_save_path``; use the
    # same column here (the old ``model_path`` key was inconsistent with it).
    model_key = ("training", "model_save_path")
    log_key = ("training", "log_file")
    if model_key in self.data.columns:
        model_path = self.data.loc[new_episode_name, model_key]
        # Missing entries are NaN/None; os.path.dirname would fail on them.
        if isinstance(model_path, str):
            self.data.loc[new_episode_name, model_key] = os.path.join(
                os.path.dirname(model_path), new_episode_name
            )
    if log_key in self.data.columns:
        log_path = self.data.loc[new_episode_name, log_key]
        if isinstance(log_path, str):
            self.data.loc[new_episode_name, log_key] = os.path.join(
                os.path.dirname(log_path), f"{new_episode_name}.txt"
            )
    self._save()
851
def remove_episode(self, episode_name: str) -> None:
    """Drop an episode's row from the records table and save the table.

    Only the metadata record is removed; model and log files on disk are left
    untouched. Names that are not in the records are ignored and nothing is
    saved.

    Parameters
    ----------
    episode_name : str
        the name of the episode to remove

    """
    if episode_name not in self.data.index:
        return
    self.data = self.data.drop(index=episode_name)
    self._save()
864
def unfinished_episodes(self) -> List:
    """List the episodes whose records are reported as unfinished.

    Each record row is wrapped in a `Run` and kept when
    `Run.unfinished()` is true (currently running or interrupted episodes).

    Returns
    -------
    interrupted_episodes: List
        a list of string names of unfinished episodes in the records

    """
    return [
        name
        for name, params in self.data.iterrows()
        if Run(name, project_path=self.project_path, params=params).unfinished()
    ]
879
def update_episode_results(
    self,
    episode_name: str,
    logs: Tuple,
    training_time: str = None,
) -> None:
    """Write the final validation metrics of a training run into its record.

    For every metric in ``logs[1]["val"]`` the last logged value is stored
    under the ``results`` group; the training time, if given, goes to
    ``meta/training_time``. The table is saved afterwards.

    Parameters
    ----------
    episode_name : str
        the name of the episode to update
    logs : tuple
        the logs returned by task.train(); the second element is the metrics
        log with per-metric value histories under the ``"val"`` key
    training_time : str, optional
        the training time

    """
    final_values = {
        ("results", metric): history[-1]
        for metric, history in logs[1]["val"].items()
    }
    if training_time is not None:
        final_values[("meta", "training_time")] = training_time
    for column, value in final_values.items():
        self.data.loc[episode_name, column] = value
    self._save()
907
def get_runs(self, episode_name: str) -> List:
    """Collect the record names that are runs of an episode.

    A record counts as a run of `episode_name` when it equals the name, or
    when splitting it on ``::`` (if present) or ``#`` gives `episode_name`
    followed by a numeric suffix, e.g. ``episode_name#0`` or
    ``episode_name::1``.

    Parameters
    ----------
    episode_name : str
        the name of the episode

    Returns
    -------
    runs_list : List
        a list of string run names (empty if `episode_name` is None)

    """
    if episode_name is None:
        return []
    matches = []
    for name in self.data.index:
        if not name.startswith(episode_name):
            continue
        separator = "::" if "::" in name else "#"
        parts = name.split(separator)
        if parts[0] == episode_name:
            if len(parts) == 1 or parts[-1].isnumeric():
                matches.append(name)
        elif name == episode_name:
            # episode names that themselves contain the separator
            matches.append(name)
    return matches
938
def _save(self):
    """Pickle a copy of the records table to ``self.path``."""
    snapshot = self.data.copy()
    snapshot.to_pickle(self.path)

A class that manages operations with all episode (or prediction) records.

SavedRuns(path: str, project_path: str)
def __init__(self, path: str, project_path: str) -> None:
    """Load the records table from a pickle file.

    Parameters
    ----------
    path : str
        the path to the pickled SavedRuns dataframe
    project_path : str
        the path to the project folder

    """
    self.project_path = project_path
    self.path = path
    # NOTE: unpickling can execute arbitrary code, so `path` must point to a
    # trusted file.
    self.data = pd.read_pickle(path)
    # Post-process with the module-level `_check_str_conversion` helper
    # (defined elsewhere in this module; presumably restores values stored
    # as strings — confirm).
    self.data = _check_str_conversion(self.data)

Initialize the class.

Parameters

path : str — the path to the pickled SavedRuns dataframe; project_path : str — the path to the project folder

path
project_path
data
def update( self, data: pandas.core.frame.DataFrame, data_path: str, annotation_path: str, name_map: Dict = None, force: bool = False) -> None:
def update(
    self,
    data: pd.DataFrame,
    data_path: str,
    annotation_path: str,
    name_map: Dict = None,
    force: bool = False,
) -> None:
    """Merge records from another project into this one and save the table.

    The incoming records are renamed according to `name_map`. Their
    project-specific paths are then pointed into this project:

    - ``training/model_save_path`` becomes ``{project_path}/results/model/{episode}``;
    - ``training/log_file`` becomes ``{project_path}/results/logs/{episode}.txt``;
    - ``training/split_path`` becomes ``{project_path}/results/splits/{basename}``,
      or None if it was missing;
    - ``data/data_path`` and ``data/annotation_path`` are replaced by the
      given values.

    Parameters
    ----------
    data : pd.DataFrame
        the new dataframe
    data_path : str
        the new data path
    annotation_path : str
        the new annotation path
    name_map : dict, optional
        the name change dictionary; keys are old episode names and values are new episode names
    force : bool, default False
        replace existing episodes if `True`

    Raises
    ------
    RuntimeError
        if an incoming episode name is already in the records and `force` is False

    """
    if name_map is None:
        name_map = {}
    # ``rename`` returns a new frame, so the caller's dataframe is not modified.
    data = data.rename(index=name_map)
    for episode in data.index:
        data.loc[episode, ("training", "model_save_path")] = os.path.join(
            self.project_path, "results", "model", episode
        )
        data.loc[episode, ("training", "log_file")] = os.path.join(
            self.project_path, "results", "logs", f"{episode}.txt"
        )
        old_split = data.loc[episode, ("training", "split_path")]
        # A missing split path may be NaN rather than None (save_episode fills
        # new columns with np.nan); os.path.basename raises TypeError on NaN.
        if pd.api.types.is_scalar(old_split) and pd.isna(old_split):
            new_split = None
        else:
            new_split = os.path.join(
                self.project_path, "results", "splits", os.path.basename(old_split)
            )
        data.loc[episode, ("training", "split_path")] = new_split
        data.loc[episode, ("data", "data_path")] = data_path
        data.loc[episode, ("data", "annotation_path")] = annotation_path
        if episode in self.data.index:
            if force:
                self.data = self.data.drop(index=[episode])
            else:
                raise RuntimeError(f"The {episode} episode name is already taken!")
    self.data = pd.concat([self.data, data])
    self._save()

Update with new data.

Parameters

data : pd.DataFrame — the new dataframe; data_path : str — the new data path; annotation_path : str — the new annotation path; name_map : dict, optional — the name change dictionary (keys are old episode names, values are new episode names); force : bool, default False — replace existing episodes if True

def get_subset(self, episode_names: List) -> pandas.core.frame.DataFrame:
def get_subset(self, episode_names: List) -> pd.DataFrame:
    """Return the raw metadata rows of the requested episodes.

    Parameters
    ----------
    episode_names : list
        a list of the episodes to include

    Returns
    -------
    subset : pd.DataFrame
        the subset of the raw metadata

    Raises
    ------
    ValueError
        if any of the names is not in the records (the first missing one is reported)

    """
    missing = [name for name in episode_names if name not in self.data.index]
    if missing:
        raise ValueError(
            f"The {missing[0]} episode is not in the records; please run `Project.list_episodes()` to explore the records"
        )
    return self.data.loc[episode_names]

Get a subset of the raw metadata.

Parameters

episode_names : list a list of the episodes to include

Returns

subset : pd.DataFrame the subset of the raw metadata

def get_saved_data_path(self, episode_name: str) -> str:
def get_saved_data_path(self, episode_name: str) -> str:
    """Look up the ``data/saved_data_path`` entry of an episode record.

    Parameters
    ----------
    episode_name : str
        the name of the episode

    Returns
    -------
    saved_data_path : str
        the saved data path

    """
    episode_record = self.data.loc[episode_name]
    data_group = episode_record["data"]
    return data_group["saved_data_path"]

Get the saved_data_path parameter for the episode.

Parameters

episode_name : str the name of the episode

Returns

saved_data_path : str the saved data path

def check_name_validity(self, episode_name: str) -> bool:
def check_name_validity(self, episode_name: str) -> bool:
    """Tell whether a name is still free in the records.

    Parameters
    ----------
    episode_name : str
        the name to check

    Returns
    -------
    result : bool
        True if the name can be used (it is not in the records yet)

    """
    name_taken = episode_name in self.data.index
    return not name_taken

Check if an episode name already exists.

Parameters

episode_name : str the name to check

Returns

result : bool True if the name can be used

def update_episode_metrics(self, episode_name: str, metrics: Dict) -> None:
def update_episode_metrics(self, episode_name: str, metrics: Dict) -> None:
    """Store evaluation results in an episode record and save the table.

    Every metric is written to the ``results`` group under its own name.

    Parameters
    ----------
    episode_name : str
        the name of the episode to update
    metrics : dict
        a dictionary of the metrics (metric name -> value)

    """
    for metric_name in metrics:
        self.data.loc[episode_name, ("results", metric_name)] = metrics[metric_name]
    self._save()

Update meta data with evaluation results.

Parameters

episode_name : str the name of the episode to update metrics : dict a dictionary of the metrics

def save_episode( self, episode_name: str, parameters: Dict, behaviors_dict: Dict, suppress_validation: bool = False, training_time: str = None) -> None:
def save_episode(
    self,
    episode_name: str,
    parameters: Dict,
    behaviors_dict: Dict,
    suppress_validation: bool = False,
    training_time: str = None,
) -> None:
    """Add the parameters of an episode to the records table and save it.

    The nested parameter dictionary is flattened into ``(group, name)``
    columns. A ``meta`` group is filled with the current time, the behaviors
    dictionary and (optionally) the training time. When `parameters` has
    more than one group, the ``losses`` group is reduced to the settings of
    the selected loss function, and the selected metrics and SSL modules
    become groups of their own (the original ``metrics`` and ``ssl`` groups
    are dropped).

    Parameters
    ----------
    episode_name : str
        the name of the episode
    parameters : dict
        the parameters to save
    behaviors_dict : dict
        the dictionary of behaviors (keys are indices, values are names)
    suppress_validation : bool, default False
        if True, existing episode with the same name will be overwritten
    training_time : str, optional
        the training time in '%H:%M:%S' format

    Raises
    ------
    ValueError
        if the episode already exists and `suppress_validation` is False

    """
    if episode_name in self.data.index and not suppress_validation:
        raise ValueError(f"Episode {episode_name} already exists!")
    pars = deepcopy(parameters)
    meta = pars.setdefault("meta", {})
    meta["time"] = strftime("%Y-%m-%d %H:%M:%S", localtime())
    meta["behaviors_dict"] = behaviors_dict
    if training_time is not None:
        meta["training_time"] = training_time
    if len(parameters.keys()) > 1:
        pars["losses"] = pars["losses"].get(pars["general"]["loss_function"], {})
        for metric_name in pars["general"]["metric_functions"]:
            pars[metric_name] = pars["metrics"].get(metric_name, {})
        if pars["general"].get("ssl", None) is not None:
            for ssl_name in pars["general"]["ssl"]:
                pars[ssl_name] = pars["ssl"].get(ssl_name, {})
        for group_name in ("metrics", "ssl"):
            pars.pop(group_name, None)
    flat = {}
    for group, group_values in pars.items():
        for name, value in group_values.items():
            flat[(group, name)] = value
    with warnings.catch_warnings():
        # adding many columns one by one triggers pandas fragmentation warnings
        warnings.filterwarnings("ignore", message="DataFrame is highly fragmented")
        object_columns = []
        for column, value in flat.items():
            if column not in self.data.columns:
                self.data[column] = np.nan
            if isinstance(value, list):
                object_columns.append(column)
        # list values can only be stored in object-dtype columns
        for column in object_columns:
            self.data[column] = self.data[column].astype(object)
        self.data.loc[episode_name] = flat
    self._save()

Save a new run record.

Parameters

episode_name : str — the name of the episode; parameters : dict — the parameters to save; behaviors_dict : dict — the dictionary of behaviors (keys are indices, values are names); suppress_validation : bool, default False — if True, an existing episode with the same name will be overwritten; training_time : str, optional — the training time in '%H:%M:%S' format

def load_parameters(self, episode_name: str) -> Dict:
def load_parameters(self, episode_name: str) -> Dict:
    """Rebuild the nested task parameters from a flat episode record.

    The parameter groups of the record are restored from the non-missing
    ``(group, name)`` columns:

    - ``data``, ``augmentations``, ``general``, ``training``, ``model`` and
      ``features`` are restored directly;
    - groups named in ``general/ssl`` are nested under ``ssl``;
    - groups named in ``general/metric_functions`` are nested under ``metrics``;
    - the ``losses`` group is nested under the name in ``general/loss_function``.

    ``general/num_classes`` is taken from `Run.get_num_classes`.

    Parameters
    ----------
    episode_name : str
        the name of the episode to load

    Returns
    -------
    parameters : dict
        the loaded task parameters

    """
    grouped = defaultdict(lambda: defaultdict(lambda: {}))
    record = self.data.loc[episode_name].dropna().to_dict()
    task_groups = ["data", "augmentations", "general", "training", "model", "features"]
    for (group, name), value in record.items():
        if group in task_groups:
            grouped[group][name] = value
    ssl_names = grouped["general"].get("ssl", None)
    metric_names = grouped["general"].get("metric_functions", None)
    loss_name = grouped["general"]["loss_function"]
    if ssl_names is None:
        ssl_names = []
    if metric_names is None:
        metric_names = []
    for (group, name), value in record.items():
        if group in ssl_names:
            grouped["ssl"][group][name] = value
        elif group in metric_names:
            grouped["metrics"][group][name] = value
        elif group == "losses":
            grouped["losses"][loss_name][name] = value
    parameters = {group: dict(entries) for group, entries in grouped.items()}
    parameters["general"]["num_classes"] = Run(
        episode_name, self.project_path, params=self.data.loc[episode_name]
    ).get_num_classes()
    return parameters

Load the task parameters from a record.

Parameters

episode_name : str the name of the episode to load

Returns

parameters : dict the loaded task parameters

def get_active_datasets(self) -> List:
def get_active_datasets(self) -> List:
    """Name the datasets used by the unfinished episodes.

    Returns
    -------
    active_datasets : list
        a list of dataset names (one per unfinished episode, in the order of
        `unfinished_episodes`)

    """
    return [
        Run(name, self.project_path, params=self.data.loc[name]).dataset_name()
        for name in self.unfinished_episodes()
    ]

Get a list of names of datasets that are used by unfinished episodes.

Returns

active_datasets : list a list of dataset names used by unfinished episodes

def list_episodes( self, episode_names: List = None, value_filter: str = '', display_parameters: List = None) -> pandas.core.frame.DataFrame:
def list_episodes(
    self,
    episode_names: List = None,
    value_filter: str = "",
    display_parameters: List = None,
) -> pd.DataFrame:
    """Get a filtered pandas dataframe with episode metadata.

    Parameters
    ----------
    episode_names : List, optional
        a list of strings of episode names; all records are used if not given
    value_filter : str, default ""
        a string of filters to apply of this general structure:
        'group_name1/par_name1::(<>=)value1,group_name2/par_name2::(<>=)value2', e.g.
        'data/overlap::=50,results/recall::>0.5,data/feature_extraction::=kinematic';
        the supported signs are >, >=, <, <=, = and !=. Values are parsed as
        floats when possible and 'True'/'False'/'None' as the corresponding
        constants; anything else is compared as a string. 'None' can only be
        used with = and !=.
    display_parameters : List, optional
        list of parameters to display, as 'group/name' strings
        (e.g. ['data/overlap', 'results/recall']) or (group, name) tuples;
        parameters that are not in the records are skipped

    Returns
    -------
    pandas.DataFrame
        the filtered dataframe (an empty dataframe if there are no records)

    Raises
    ------
    ValueError
        if `value_filter` is malformed or uses an unsupported sign
    KeyError
        if a filtered parameter or a requested episode is not in the records

    """
    if episode_names is not None:
        data = deepcopy(self.data.loc[episode_names])
    else:
        data = deepcopy(self.data)
    if len(data) == 0:
        return pd.DataFrame()
    filters = value_filter.split(",") if value_filter else []
    try:
        for filter_string in filters:
            full_name, condition = filter_string.split("::")
            group_name, par_name = full_name.split("/")
            if not condition:
                raise ValueError("The condition is empty")
            sign, value = condition[0], condition[1:]
            # Two-character signs (>=, <=, !=) leave their "=" at the start
            # of the value part.
            if value[:1] == "=":
                sign += "="
                value = value[1:]
            if sign not in (">", ">=", "<", "<=", "=", "!="):
                raise ValueError(
                    "Please use one of the signs: [>, <, >=, <=, =, !=]"
                )
            try:
                value = float(value)
            except ValueError:
                if value == "True":
                    value = True
                elif value == "False":
                    value = False
                elif value == "None":
                    value = None
            column = data[group_name][par_name]
            if value is None:
                if sign == "=":
                    data = data[column.isna()]
                elif sign == "!=":
                    data = data[~column.isna()]
                else:
                    # Previously ignored silently, returning unfiltered data.
                    raise ValueError("None can only be compared with = or !=")
            elif sign == ">":
                data = data[column > value]
            elif sign == ">=":
                data = data[column >= value]
            elif sign == "<":
                data = data[column < value]
            elif sign == "<=":
                data = data[column <= value]
            elif sign == "=":
                data = data[column == value]
            else:
                data = data[column != value]
    except ValueError as err:
        raise ValueError(
            f"The {value_filter} filter is not valid, please use the following format:"
            f" 'group1/parameter1::[sign][value],group2/parameter2::[sign][value]', "
            f"e.g. 'training/num_epochs::>=200,model/num_f_maps::=128,meta/time::>2022-06-01'"
        ) from err
    if display_parameters is not None:
        columns = []
        # Convert each entry separately so mixed string/tuple lists (and an
        # empty list) are handled.
        for parameter in display_parameters:
            if isinstance(parameter, str):
                parts = parameter.split("/")
                parameter = (parts[0], parts[1])
            columns.append(parameter)
        columns = [column for column in columns if column in data.columns]
        data = data[columns]
    return data

Get a filtered pandas dataframe with episode metadata.

Parameters

episode_names : List — a list of strings of episode names; value_filter : str — a string of filters to apply, of this general structure: 'group_name1/par_name1::(<>=)value1,group_name2/par_name2::(<>=)value2', e.g. 'data/overlap::=50,results/recall::>0.5,data/feature_extraction::=kinematic'; display_parameters : List — list of parameters to display (e.g. ['data/overlap', 'results/recall'])

Returns

pandas.DataFrame the filtered dataframe

def rename_episode(self, episode_name, new_episode_name):
def rename_episode(self, episode_name, new_episode_name):
    """Rename an episode record.

    The record's ``training/model_save_path`` and ``training/log_file``
    entries are redirected so that their last path component matches the new
    name (the log file becomes ``{new_episode_name}.txt``). Only the metadata
    record changes; nothing on disk is moved or renamed.

    Parameters
    ----------
    episode_name : str
        the name of the episode to rename
    new_episode_name : str
        the new name of the episode

    Raises
    ------
    ValueError
        if `episode_name` is not in the records or `new_episode_name` is
        already taken

    """
    if episode_name not in self.data.index or new_episode_name in self.data.index:
        raise ValueError("The names are wrong")
    # Relabel in place so the row keeps its position in the table.
    self.data = self.data.rename(index={episode_name: new_episode_name})
    # ``update`` stores the model folder under ``model_save_path``; use the
    # same column here (the old ``model_path`` key was inconsistent with it).
    model_key = ("training", "model_save_path")
    log_key = ("training", "log_file")
    if model_key in self.data.columns:
        model_path = self.data.loc[new_episode_name, model_key]
        # Missing entries are NaN/None; os.path.dirname would fail on them.
        if isinstance(model_path, str):
            self.data.loc[new_episode_name, model_key] = os.path.join(
                os.path.dirname(model_path), new_episode_name
            )
    if log_key in self.data.columns:
        log_path = self.data.loc[new_episode_name, log_key]
        if isinstance(log_path, str):
            self.data.loc[new_episode_name, log_key] = os.path.join(
                os.path.dirname(log_path), f"{new_episode_name}.txt"
            )
    self._save()

Rename an episode.

Parameters

episode_name : str the name of the episode to rename new_episode_name : str the new name of the episode

def remove_episode(self, episode_name: str) -> None:
def remove_episode(self, episode_name: str) -> None:
    """Drop an episode's row from the records table and save the table.

    Only the metadata record is removed; model and log files on disk are left
    untouched. Names that are not in the records are ignored and nothing is
    saved.

    Parameters
    ----------
    episode_name : str
        the name of the episode to remove

    """
    if episode_name not in self.data.index:
        return
    self.data = self.data.drop(index=episode_name)
    self._save()

Remove the metafile record of an episode (model and log files on disk are not deleted).

Parameters

episode_name : str the name of the episode to remove

def unfinished_episodes(self) -> List:
def unfinished_episodes(self) -> List:
    """List the episodes whose records are reported as unfinished.

    Each record row is wrapped in a `Run` and kept when
    `Run.unfinished()` is true (currently running or interrupted episodes).

    Returns
    -------
    interrupted_episodes: List
        a list of string names of unfinished episodes in the records

    """
    return [
        name
        for name, params in self.data.iterrows()
        if Run(name, project_path=self.project_path, params=params).unfinished()
    ]

Get a list of unfinished episodes (currently running or interrupted).

Returns

interrupted_episodes: List a list of string names of unfinished episodes in the records

def update_episode_results(self, episode_name: str, logs: Tuple, training_time: str = None) -> None:
def update_episode_results(
    self,
    episode_name: str,
    logs: Tuple,
    training_time: str = None,
) -> None:
    """Write the final validation metrics of a training run into its record.

    For every metric in ``logs[1]["val"]`` the last logged value is stored
    under the ``results`` group; the training time, if given, goes to
    ``meta/training_time``. The table is saved afterwards.

    Parameters
    ----------
    episode_name : str
        the name of the episode to update
    logs : tuple
        the logs returned by task.train(); the second element is the metrics
        log with per-metric value histories under the ``"val"`` key
    training_time : str, optional
        the training time

    """
    final_values = {
        ("results", metric): history[-1]
        for metric, history in logs[1]["val"].items()
    }
    if training_time is not None:
        final_values[("meta", "training_time")] = training_time
    for column, value in final_values.items():
        self.data.loc[episode_name, column] = value
    self._save()

Add results to an episode record.

Parameters

episode_name : str — the name of the episode to update; logs : tuple — the logs returned by task.train(), whose second element is the metrics log; training_time : str — the training time

def get_runs(self, episode_name: str) -> List:
def get_runs(self, episode_name: str) -> List:
    """Collect the record names that are runs of an episode.

    A record counts as a run of `episode_name` when it equals the name, or
    when splitting it on ``::`` (if present) or ``#`` gives `episode_name`
    followed by a numeric suffix, e.g. ``episode_name#0`` or
    ``episode_name::1``.

    Parameters
    ----------
    episode_name : str
        the name of the episode

    Returns
    -------
    runs_list : List
        a list of string run names (empty if `episode_name` is None)

    """
    if episode_name is None:
        return []
    matches = []
    for name in self.data.index:
        if not name.startswith(episode_name):
            continue
        separator = "::" if "::" in name else "#"
        parts = name.split(separator)
        if parts[0] == episode_name:
            if len(parts) == 1 or parts[-1].isnumeric():
                matches.append(name)
        elif name == episode_name:
            # episode names that themselves contain the separator
            matches.append(name)
    return matches

Get a list of runs with this episode name (episodes like episode_name#0 or episode_name::0).

Parameters

episode_name : str the name of the episode

Returns

runs_list : List a list of string run names

class Searches(SavedRuns):
 944class Searches(SavedRuns):
 945    """A class that manages operations with search records."""
 946
 947    def save_search(
 948        self,
 949        search_name: str,
 950        parameters: Dict,
 951        n_trials: int,
 952        best_params: Dict,
 953        best_value: float,
 954        metric: str,
 955        search_space: Dict,
 956    ) -> None:
 957        """Save a new search record.
 958
 959        Parameters
 960        ----------
 961        search_name : str
 962            the name of the search to save
 963        parameters : dict
 964            the task parameters to save
 965        n_trials : int
 966            the number of trials in the search
 967        best_params : dict
 968            the best parameters dictionary
 969        best_value : float
 970            the best valie
 971        metric : str
 972            the name of the objective metric
 973        search_space : dict
 974            a dictionary representing the search space; of this general structure:
 975            {'group/param_name': ('float/int/float_log/int_log', start, end),
 976            'group/param_name': ('categorical', [choices])}, e.g.
 977            {'data/overlap': ('int', 5, 100), 'training/lr': ('float_log', 1e-4, 1e-2),
 978            'data/feature_extraction': ('categorical', ['kinematic', 'bones'])}
 979
 980        """
 981        pars = deepcopy(parameters)
 982        pars["results"] = {"best_value": best_value, "best_params": best_params}
 983        pars["meta"] = {
 984            "objective": metric,
 985            "n_trials": n_trials,
 986            "search_space": search_space,
 987        }
 988        self.save_episode(search_name, pars, {})
 989
 990    def get_best_params_raw(self, search_name: str) -> Dict:
 991        """Get the raw dictionary of best parameters found by a search.
 992
 993        Parameters
 994        ----------
 995        search_name : str
 996            the name of the search
 997
 998        Returns
 999        -------
1000        best_params : dict
1001            a dictionary of the best parameters where the keys are in '{group}/{name}' format
1002
1003        """
1004        return self.data.loc[search_name]["results"]["best_params"]
1005
1006    def get_best_params(
1007        self,
1008        search_name: str,
1009        load_parameters: List = None,
1010        round_to_binary: List = None,
1011    ) -> Dict:
1012        """Get the best parameters from a search.
1013
1014        Parameters
1015        ----------
1016        search_name : str
1017            the name of the search
1018        load_parameters : List, optional
1019            a list of string names of the parameters to load (if not provided all parameters are loaded)
1020        round_to_binary : List, optional
1021            a list of string names of the loaded parameters that should be rounded to the nearest power of two
1022
1023        Returns
1024        -------
1025        best_params : dict
1026            a dictionary of the best parameters
1027
1028        """
1029        if round_to_binary is None:
1030            round_to_binary = []
1031        params = self.data.loc[search_name]["results"]["best_params"]
1032        if load_parameters is not None:
1033            params = {k: v for k, v in params.items() if k in load_parameters}
1034        for par_name in round_to_binary:
1035            if par_name not in params:
1036                continue
1037            if not isinstance(params[par_name], float) and not isinstance(
1038                params[par_name], int
1039            ):
1040                raise TypeError(
1041                    f"Cannot round {par_name} parameter of type {type(par_name)} to a power of two"
1042                )
1043            i = 1
1044            while 2**i < params[par_name]:
1045                i += 1
1046            if params[par_name] - (2 ** (i - 1)) < (2**i) - params[par_name]:
1047                params[par_name] = 2 ** (i - 1)
1048            else:
1049                params[par_name] = 2**i
1050        res = defaultdict(lambda: defaultdict(lambda: {}))
1051        for k, v in params.items():
1052            big_key, small_key = k.split("/")[0], "/".join(k.split("/")[1:])
1053            if len(small_key.split("/")) == 1:
1054                res[big_key][small_key] = v
1055            else:
1056                group, key = small_key.split("/")
1057                res[big_key][group][key] = v
1058        model = self.data.loc[search_name]["general"]["model_name"]
1059        return res, model

A class that manages operations with search records.

def get_best_params_raw(self, search_name: str) -> Dict:
 990    def get_best_params_raw(self, search_name: str) -> Dict:
 991        """Get the raw dictionary of best parameters found by a search.
 992
 993        Parameters
 994        ----------
 995        search_name : str
 996            the name of the search
 997
 998        Returns
 999        -------
1000        best_params : dict
1001            a dictionary of the best parameters where the keys are in '{group}/{name}' format
1002
1003        """
1004        return self.data.loc[search_name]["results"]["best_params"]

Get the raw dictionary of best parameters found by a search.

Parameters

search_name : str the name of the search

Returns

best_params : dict a dictionary of the best parameters where the keys are in '{group}/{name}' format

def get_best_params( self, search_name: str, load_parameters: List = None, round_to_binary: List = None) -> Dict:
1006    def get_best_params(
1007        self,
1008        search_name: str,
1009        load_parameters: List = None,
1010        round_to_binary: List = None,
1011    ) -> Dict:
1012        """Get the best parameters from a search.
1013
1014        Parameters
1015        ----------
1016        search_name : str
1017            the name of the search
1018        load_parameters : List, optional
1019            a list of string names of the parameters to load (if not provided all parameters are loaded)
1020        round_to_binary : List, optional
1021            a list of string names of the loaded parameters that should be rounded to the nearest power of two
1022
1023        Returns
1024        -------
1025        best_params : dict
1026            a dictionary of the best parameters
1027
1028        """
1029        if round_to_binary is None:
1030            round_to_binary = []
1031        params = self.data.loc[search_name]["results"]["best_params"]
1032        if load_parameters is not None:
1033            params = {k: v for k, v in params.items() if k in load_parameters}
1034        for par_name in round_to_binary:
1035            if par_name not in params:
1036                continue
1037            if not isinstance(params[par_name], float) and not isinstance(
1038                params[par_name], int
1039            ):
1040                raise TypeError(
1041                    f"Cannot round {par_name} parameter of type {type(par_name)} to a power of two"
1042                )
1043            i = 1
1044            while 2**i < params[par_name]:
1045                i += 1
1046            if params[par_name] - (2 ** (i - 1)) < (2**i) - params[par_name]:
1047                params[par_name] = 2 ** (i - 1)
1048            else:
1049                params[par_name] = 2**i
1050        res = defaultdict(lambda: defaultdict(lambda: {}))
1051        for k, v in params.items():
1052            big_key, small_key = k.split("/")[0], "/".join(k.split("/")[1:])
1053            if len(small_key.split("/")) == 1:
1054                res[big_key][small_key] = v
1055            else:
1056                group, key = small_key.split("/")
1057                res[big_key][group][key] = v
1058        model = self.data.loc[search_name]["general"]["model_name"]
1059        return res, model

Get the best parameters from a search.

Parameters

search_name : str — the name of the search;
load_parameters : List, optional — a list of string names of the parameters to load (if not provided, all parameters are loaded);
round_to_binary : List, optional — a list of string names of the loaded parameters that should be rounded to the nearest power of two

Returns

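best_params : dict — a nested dictionary of the best parameters, grouped by the '/'-separated parts of their names;
model_name : str — the model name recorded in the search's general parameters (the method returns the pair (best_params, model_name))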

class Suggestions(SavedRuns):
class Suggestions(SavedRuns):
    """Manage suggestion records, stored through the ``SavedRuns`` episode API."""

    def save_suggestion(self, episode_name: str, parameters: Dict, meta_parameters):
        """Store a new suggestion record.

        Parameters
        ----------
        episode_name : str
            the name under which the suggestion is saved
        parameters : dict
            the suggestion parameters; a deep copy is stored, so the caller's
            dictionary is not modified here
        meta_parameters
            stored under the ``"meta"`` key of the copied parameters

        """
        record = deepcopy(parameters)
        record["meta"] = meta_parameters
        super().save_episode(episode_name, record, behaviors_dict=None)

A class that manages operations with suggestion records.

def save_suggestion(self, episode_name: str, parameters: Dict, meta_parameters):
1065    def save_suggestion(self, episode_name: str, parameters: Dict, meta_parameters):
1066        """Save a new suggestion record."""
1067        pars = deepcopy(parameters)
1068        pars["meta"] = meta_parameters
1069        super().save_episode(episode_name, pars, behaviors_dict=None)

Save a new suggestion record.

class SavedStores:
class SavedStores:
    """Manage the records of saved datasets (stores) kept in a pickled dataframe."""

    def __init__(self, path):
        """Load the saved-dataset records.

        Parameters
        ----------
        path : str
            the path to the pickled dataframe of saved dataset records

        """
        self.path = path
        self.data = pd.read_pickle(path)
        # parameter keys that ``find_name`` ignores when comparing records
        self.skip_keys = [
            "feature_save_path",
            "saved_data_path",
            "real_lens",
            "recompute_annotation",
        ]

    def clear(self) -> None:
        """Remove all dataset records.

        All rows are dropped in one operation and the dataframe is written to
        disk once; nothing is written when there are no records.
        """
        if len(self.data.index) == 0:
            return
        self.data = self.data.drop(index=self.data.index)
        self._save()

    def dataset_names(self) -> List:
        """List the names of all saved datasets.

        Returns
        -------
        dataset_names : List
            the index labels of the records dataframe, in order

        """
        return [*self.data.index]

    def remove(self, names: List) -> None:
        """Remove several dataset records.

        Unknown names are ignored. The matching rows are dropped together and
        the dataframe is written to disk once (only if something was removed).

        Parameters
        ----------
        names : List
            a list of string names of the datasets to delete

        """
        existing = [name for name in names if name in self.data.index]
        if not existing:
            return
        # dict.fromkeys removes repeated names while keeping their order
        self.data = self.data.drop(index=list(dict.fromkeys(existing)))
        self._save()

    def remove_dataset(self, dataset_name: str) -> None:
        """Remove one dataset record and persist the change.

        Nothing happens (and nothing is written) if the name is unknown.

        Parameters
        ----------
        dataset_name : str
            the name of the dataset to remove

        """
        if dataset_name not in self.data.index:
            return
        self.data = self.data.drop(index=dataset_name)
        self._save()

    def find_name(self, parameters: Dict) -> str:
        """Find an existing record whose parameters match the given ones.

        Keys whose value is None and keys in ``self.skip_keys`` are ignored.
        If any remaining key is not a column, no record can match. A record
        matches when it equals every remaining parameter and all of its other
        (non-skipped) columns are empty.

        Parameters
        ----------
        parameters : dict
            a dictionary of data parameters

        Returns
        -------
        name : str
            the name of the first matching record in dataframe order, or None
            if there is no match

        """

        def is_blank(cell) -> bool:
            # pd.isnull returns an array for list-like cells; those count as filled
            missing = pd.isnull(cell)
            return isinstance(missing, bool) and missing

        criteria = {
            key: value
            for key, value in deepcopy(parameters).items()
            if value is not None and key not in self.skip_keys
        }
        if any(key not in self.data.columns for key in criteria):
            return None
        matches = self.data[
            (self.data[list(criteria)] == pd.Series(criteria)).all(axis=1)
        ]
        for row_name, row in matches.iterrows():
            other_columns = (
                column
                for column in matches.columns
                if column not in self.skip_keys and column not in criteria
            )
            if all(is_blank(row[column]) for column in other_columns):
                return row_name
        return None

    def save_store(self, episode_name: str, parameters: Dict) -> None:
        """Add a dataset record unless an equivalent one already exists.

        A NaN-filled column is created for every parameter key that is not yet
        a column. The record is stored under ``episode_name`` only when
        ``find_name`` finds no matching record; the dataframe is saved either way.

        Parameters
        ----------
        episode_name : str
            the name of the dataset
        parameters : dict
            a dictionary of data parameters

        """
        record = deepcopy(parameters)
        new_columns = [key for key in parameters if key not in self.data.columns]
        for column in new_columns:
            self.data[column] = np.nan
        if self.find_name(record) is None:
            self.data.loc[episode_name] = record
        self._save()

    def _save(self):
        """Write the records dataframe back to ``self.path`` as a pickle."""
        self.data.to_pickle(self.path)

    def check_name_validity(self, store_name: str) -> bool:
        """Tell whether a store name is still free to use.

        Parameters
        ----------
        store_name : str
            the name to check

        Returns
        -------
        result : bool
            True if no record with this name exists yet

        """
        return store_name not in self.data.index

A class that manages operations with saved dataset records.

SavedStores(path)
1075    def __init__(self, path):
1076        """Initialize the class.
1077
1078        Parameters
1079        ----------
1080        path : str
1081            the path to the pickled SavedRuns dataframe
1082
1083        """
1084        self.path = path
1085        self.data = pd.read_pickle(path)
1086        self.skip_keys = [
1087            "feature_save_path",
1088            "saved_data_path",
1089            "real_lens",
1090            "recompute_annotation",
1091        ]

Initialize the class.

Parameters

path : str — the path to the pickled dataframe of saved dataset records

path
data
skip_keys
def clear(self) -> None:
1093    def clear(self) -> None:
1094        """Remove all datasets."""
1095        for dataset_name in self.data.index:
1096            self.remove_dataset(dataset_name)

Remove all datasets.

def dataset_names(self) -> List:
1098    def dataset_names(self) -> List:
1099        """Get a list of dataset names.
1100
1101        Returns
1102        -------
1103        dataset_names : List
1104            a list of string dataset names
1105
1106        """
1107        return list(self.data.index)

Get a list of dataset names.

Returns

dataset_names : List a list of string dataset names

def remove(self, names: List) -> None:
1109    def remove(self, names: List) -> None:
1110        """Remove some datasets.
1111
1112        Parameters
1113        ----------
1114        names : List
1115            a list of string names of the datasets to delete
1116
1117        """
1118        for dataset_name in names:
1119            if dataset_name in self.data.index:
1120                self.remove_dataset(dataset_name)

Remove some datasets.

Parameters

names : List a list of string names of the datasets to delete

def remove_dataset(self, dataset_name: str) -> None:
1122    def remove_dataset(self, dataset_name: str) -> None:
1123        """Remove a dataset record.
1124
1125        Parameters
1126        ----------
1127        dataset_name : str
1128            the name of the dataset to remove
1129
1130        """
1131        if dataset_name in self.data.index:
1132            self.data = self.data.drop(index=dataset_name)
1133            self._save()

Remove a dataset record.

Parameters

dataset_name : str the name of the dataset to remove

def find_name(self, parameters: Dict) -> str:
1135    def find_name(self, parameters: Dict) -> str:
1136        """Find a record that satisfies the parameters (if it exists).
1137
1138        Parameters
1139        ----------
1140        parameters : dict
1141            a dictionary of data parameters
1142
1143        Returns
1144        -------
1145        name : str
1146            the name of a record that has the same parameters (None if it does not exist; the earliest if there are
1147            several)
1148
1149        """
1150        filter = deepcopy(parameters)
1151        for key, value in parameters.items():
1152            if value is None or key in self.skip_keys:
1153                filter.pop(key)
1154            elif key not in self.data.columns:
1155                return None
1156        saved_annotation = self.data[
1157            (self.data[list(filter)] == pd.Series(filter)).all(axis=1)
1158        ]
1159        for i in range(len(saved_annotation)):
1160            ok = True
1161            for key in saved_annotation.columns:
1162                if key in self.skip_keys:
1163                    continue
1164                isnull = pd.isnull(saved_annotation.iloc[i][key])
1165                if not isinstance(isnull, bool):
1166                    isnull = False
1167                if key not in filter and not isnull:
1168                    ok = False
1169            if ok:
1170                name = saved_annotation.iloc[i].name
1171                return name
1172        return None

Find a record that satisfies the parameters (if it exists).

Parameters

parameters : dict a dictionary of data parameters

Returns

name : str the name of a record that has the same parameters (None if it does not exist; the earliest if there are several)

def save_store(self, episode_name: str, parameters: Dict) -> None:
1174    def save_store(self, episode_name: str, parameters: Dict) -> None:
1175        """Save a new saved dataset record.
1176
1177        Parameters
1178        ----------
1179        episode_name : str
1180            the name of the dataset
1181        parameters : dict
1182            a dictionary of data parameters
1183
1184        """
1185        pars = deepcopy(parameters)
1186        for k, v in parameters.items():
1187            if k not in self.data.columns:
1188                self.data[k] = np.nan
1189        if self.find_name(pars) is None:
1190            self.data.loc[episode_name] = pars
1191        self._save()

Save a new saved dataset record.

Parameters

episode_name : str — the name of the dataset;
parameters : dict — a dictionary of data parameters

def check_name_validity(self, store_name: str) -> bool:
1197    def check_name_validity(self, store_name: str) -> bool:
1198        """Check if a store name already exists.
1199
1200        Parameters
1201        ----------
1202        store_name : str
1203            the name to check
1204
1205        Returns
1206        -------
1207        result : bool
1208            True if the name can be used
1209
1210        """
1211        if store_name in self.data.index:
1212            return False
1213        else:
1214            return True

Check whether a store name is still available, i.e. not already used by an existing record.

Parameters

store_name : str the name to check

Returns

result : bool True if the name can be used