dlc2action.project.project

Project interface

   1#
   2# Copyright 2020-2022 by A. Mathis Group and contributors. All rights reserved.
   3#
   4# This project and all its files are licensed under GNU AGPLv3 or later version. A copy is included in dlc2action/LICENSE.AGPL.
   5#
   6"""
   7Project interface
   8"""
   9import copy
  10from email.policy import default
  11import os
  12from re import search
  13from typing import Dict, List, Tuple, Union, Set, Iterable, Any, Optional
  14import shutil
  15
  16from numpy import ndarray
  17from ruamel.yaml import YAML
  18import pickle
  19import pandas as pd
  20from dlc2action.data.dataset import BehaviorDataset
  21from dlc2action.utils import apply_threshold
  22from collections.abc import Mapping
  23from collections import defaultdict
  24
  25from dlc2action.task.task_dispatcher import TaskDispatcher
  26import warnings
  27from copy import deepcopy, copy
  28import time
  29import numpy as np
  30from matplotlib import pyplot as plt
  31from matplotlib import cm
  32from itertools import product
  33from collections.abc import Iterable
  34import optuna
  35import plotly
  36import torch
  37from pathlib import Path
  38from dlc2action import options, __version__
  39from ruamel.yaml.comments import CommentedMap, CommentedSet
  40from tqdm import tqdm
  41from dlc2action.project.meta import (
  42    Searches,
  43    SavedStores,
  44    Run,
  45    SavedRuns,
  46    DecisionThresholds,
  47)
  48
  49
  50class Project:
  51    """
  52    A class to create and maintain the project files + keep track of experiments
  53    """
  54
  55    def __init__(
  56        self,
  57        name: str,
  58        data_type: str = None,
  59        annotation_type: str = "none",
  60        projects_path: str = None,
  61        data_path: Union[str, List] = None,
  62        annotation_path: Union[str, List] = None,
  63        copy: bool = False,
  64    ) -> None:
  65        """
  66        Parameters
  67        ----------
  68        name : str
  69            name of the project
  70        data_type : str, optional
  71            data type (run Project.data_types() to see available options; has to be provided if the project is being
  72            created)
  73        annotation_type : str, default 'none'
  74            annotation type (run Project.annotation_types() to see available options)
  75        projects_path : str, optional
  76            path to the projects folder (is filled with ~/DLC2Action by default)
  77        data_path : str, optional
  78            path to the folder containing input files for the project (has to be provided if the project is being
  79            created)
  80        annotation_path : str, optional
  81            path to the folder containing annotation files for the project
  82        copy : bool, default False
  83            if True, the files from annotation_path and data_path will be copied to the projects folder;
  84            otherwise they will be moved
  85        """
  86
  87        if projects_path is None:
  88            projects_path = os.path.join(str(Path.home()), "DLC2Action")
  89        if not os.path.exists(projects_path):
  90            os.mkdir(projects_path)
  91        self.project_path = os.path.join(projects_path, name)
  92        self.name = name
  93        self.data_type = data_type
  94        self.annotation_type = annotation_type
  95        self.data_path = data_path
  96        self.annotation_path = annotation_path
  97        if not os.path.exists(self.project_path):
  98            if data_type is None:
  99                raise ValueError(
 100                    "The data_type parameter is necessary when creating a new project!"
 101                )
 102            self._initialize_project(
 103                data_type, annotation_type, data_path, annotation_path, copy
 104            )
 105        else:
 106            self.annotation_type, self.data_type = self._read_types()
 107            if data_type != self.data_type and data_type is not None:
 108                raise ValueError(
 109                    f"The project has already been initialized with data_type={self.data_type}!"
 110                )
 111            if annotation_type != self.annotation_type and annotation_type != "none":
 112                raise ValueError(
 113                    f"The project has already been initialized with annotation_type={self.annotation_type}!"
 114                )
 115            self.annotation_path, data_path = self._read_paths()
 116            if self.data_path is None:
 117                self.data_path = data_path
 118            # if data_path != self.data_path and data_path is not None:
 119            #     raise ValueError(
 120            #         f"The project has already been initialized with data_path={self.data_path}!"
 121            #     )
 122            if annotation_path != self.annotation_path and annotation_path is not None:
 123                raise ValueError(
 124                    f"The project has already been initialized with annotation_path={self.annotation_path}!"
 125                )
 126        self._update_configs()
 127
 128    def _aggregate_predictions(
 129        self,
 130        prediction_name: str,
 131        episode_names: List,
 132        load_epochs: List = None,
 133        parameters_update: Dict = None,
 134        data_path: str = None,
 135        file_paths: Set = None,
 136        mode: str = "all",
 137        augment_n: int = 0,
 138        evaluate: bool = False,
 139        task: TaskDispatcher = None,
 140        embedding: bool = False,
 141    ) -> Tuple[TaskDispatcher, Dict, str, torch.Tensor]:
 142        """
 143        Generate a prediction
 144        """
 145
 146        if load_epochs is None:
 147            load_epochs = [None for _ in episode_names]
 148        prediction = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: 0)))
 149        cnt = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: 0)))
 150        behs = set(self.get_behavior_dictionary(episode_names[0]).values())
 151        if not all(
 152            [
 153                set(self.get_behavior_dictionary(x).values()) == behs
 154                for x in episode_names
 155            ]
 156        ):
 157            raise ValueError(f"The behavior sets are different in {episode_names}")
 158        behaviors = set()
 159        for i, episode_name in enumerate(episode_names):
 160            task, parameters, data_mode, new_pred, _ = self._make_prediction(
 161                prediction_name,
 162                episode_names=[episode_name],
 163                load_epochs=[load_epochs[i]],
 164                parameters_update=parameters_update,
 165                data_path=data_path,
 166                file_paths=file_paths,
 167                mode=mode,
 168                augment_n=augment_n,
 169                evaluate=evaluate,
 170                task=task,
 171                embedding=embedding,
 172            )
 173            new_pred = task.dataset(data_mode).generate_full_length_prediction(new_pred)
 174            beh_dict = task.behaviors_dict()
 175            for video_id, video_values in new_pred.items():
 176                for clip_id, clip_prediction in video_values.items():
 177                    for beh_i in range(clip_prediction.shape[0]):
 178                        prediction[video_id][clip_id][
 179                            beh_dict[beh_i]
 180                        ] += clip_prediction[beh_i, :].unsqueeze(0)
 181                        cnt[video_id][clip_id][beh_dict[beh_i]] += 1
 182                        behaviors.add(beh_dict[beh_i])
 183        output = defaultdict(lambda: {})
 184        # behaviors = sorted(behaviors)
 185        behavior_indices = sorted(
 186            [x for x in task.behaviors_dict().keys() if x != -100]
 187        )
 188        behaviors = [task.behaviors_dict()[key] for key in behavior_indices]
 189        for video_id, video_values in prediction.items():
 190            for clip_id, clip_values in video_values.items():
 191                pred = torch.cat(
 192                    [
 193                        clip_values[beh] / cnt[video_id][clip_id][beh]
 194                        for beh in behaviors
 195                    ],
 196                    0,
 197                )
 198                output[video_id][clip_id] = pred
 199        return task, parameters, data_mode, dict(output), None
 200
 201    def _make_prediction(
 202        self,
 203        prediction_name: str,
 204        episode_names: List,
 205        load_epochs: List = None,
 206        parameters_update: Dict = None,
 207        data_path: str = None,
 208        file_paths: Set = None,
 209        mode: str = "all",
 210        augment_n: int = 0,
 211        evaluate: bool = False,
 212        task: TaskDispatcher = None,
 213        embedding: bool = False,
 214    ) -> Tuple[TaskDispatcher, Dict, str, torch.Tensor]:
 215        """
 216        Generate a prediction
 217        """
 218
 219        names = []
 220        epochs = []
 221        if load_epochs is None:
 222            load_epochs = [None for _ in episode_names]
 223        if len(load_epochs) != len(episode_names):
 224            raise ValueError(
 225                f"The length of load_epochs and the length of episode_names should be the same!"
 226            )
 227        for i, episode_name in enumerate(episode_names):
 228            names += self._episodes().get_runs(episode_name)
 229            epochs.append(load_epochs[i])
 230        if len(names) == 0:
 231            warnings.warn(f"None of the episodes {episode_names} exist!")
 232            names = [None]
 233        episodes = self._episodes()
 234        lengths = [
 235            episodes.load_parameters(name)["general"]["len_segment"] for name in names
 236        ]
 237        overlaps = [
 238            episodes.load_parameters(name)["general"]["overlap"] for name in names
 239        ]
 240        if not all([x == lengths[0] for x in lengths]):
 241            raise ValueError(f"Episodes {episode_names} have different segment lengths")
 242        if not all([x == overlaps[0] for x in overlaps]):
 243            raise ValueError(f"Episodes {episode_names} have different overlaps")
 244        load_epochs = epochs
 245        prediction = None
 246        decision_thresholds = None
 247        time_total = 0
 248        behavior_dicts = [
 249            self.get_behavior_dictionary(episode_name) for episode_name in names
 250        ]
 251        if not all(
 252            [
 253                set(d.values()) == set(behavior_dicts[0].values())
 254                for d in behavior_dicts[1:]
 255            ]
 256        ):
 257            raise ValueError(
 258                f"Episodes {episode_names} have different sets of behaviors!"
 259            )
 260        behavior_indices = [x for x in behavior_dicts[0].keys() if x != -100]
 261        behaviors = [behavior_dicts[0][i] for i in behavior_indices]
 262        cnt = defaultdict(lambda: 0)
 263        behavior_probs = defaultdict(lambda: 0)
 264        for episode_name, load_epoch, behavior_dict in zip(
 265            names, load_epochs, behavior_dicts
 266        ):
 267            print(f"episode {episode_name}")
 268            task, parameters, data_mode = self._make_task_prediction(
 269                prediction_name=prediction_name,
 270                load_episode=episode_name,
 271                parameters_update=parameters_update,
 272                load_epoch=load_epoch,
 273                data_path=data_path,
 274                mode=mode,
 275                file_paths=file_paths,
 276                task=task,
 277                decision_thresholds=decision_thresholds,
 278            )
 279            behavior_indices_cur = [x for x in behavior_dict.keys() if x != -100]
 280            behaviors_cur = [behavior_dict[i] for i in behavior_indices_cur]
 281            # data_mode = "train" if mode == "all" else mode
 282            time_start = time.time()
 283            new_pred = task.predict(
 284                data_mode,
 285                raw_output=True,
 286                apply_primary_function=True,
 287                augment_n=augment_n,
 288                embedding=embedding,
 289            )
 290            for j, beh in enumerate(behaviors_cur):
 291                cnt[beh] += 1
 292                behavior_probs[beh] += new_pred[:, j, :].unsqueeze(1)
 293            # indices = [
 294            #     behaviors.index(behavior_dict[i]) for i in range(new_pred.shape[1])
 295            # ]
 296            # new_pred = new_pred[:, indices, :]
 297            time_end = time.time()
 298            time_total += time_end - time_start
 299            if evaluate:
 300                _, metrics = task.evaluate_prediction(new_pred, data=data_mode)
 301                if mode == "val":
 302                    self._update_episode_metrics(episode_name, metrics)
 303            # if prediction is None:
 304            #     prediction = new_pred
 305            # else:
 306            #     prediction += new_pred
 307            print("\n")
 308        prediction = torch.cat([behavior_probs[beh] / cnt[beh] for beh in behaviors], 1)
 309        hours = int(time_total // 3600)
 310        time_total -= hours * 3600
 311        minutes = int(time_total // 60)
 312        time_total -= minutes * 60
 313        seconds = int(time_total)
 314        inference_time = f"{hours}:{minutes:02}:{seconds:02}"
 315        # prediction /= len(names)
 316        return task, parameters, data_mode, prediction, inference_time
 317
 318    def _make_task_prediction(
 319        self,
 320        prediction_name: str,
 321        load_episode: str = None,
 322        parameters_update: Dict = None,
 323        load_epoch: int = None,
 324        data_path: str = None,
 325        mode: str = "val",
 326        file_paths: Set = None,
 327        decision_thresholds: List = None,
 328        task: TaskDispatcher = None,
 329    ) -> Tuple[TaskDispatcher, Dict, str]:
 330        """
 331        Make a `TaskDispatcher` object that will be used to generate a prediction
 332        """
 333
 334        if parameters_update is None:
 335            parameters_update = {}
 336        parameters_update_second = {}
 337        if mode == "all" or data_path is not None or file_paths is not None:
 338            parameters_update_second["training"] = {
 339                "val_frac": 0,
 340                "test_frac": 0,
 341                "partition_method": "random",
 342                "save_split": False,
 343                "split_path": None,
 344            }
 345            mode = "train"
 346        if decision_thresholds is not None:
 347            if (
 348                len(decision_thresholds)
 349                == self._episode(load_episode).get_num_classes()
 350            ):
 351                parameters_update_second["general"] = {
 352                    "threshold_value": decision_thresholds
 353                }
 354            else:
 355                raise ValueError(
 356                    f"The length of the decision thresholds {decision_thresholds} "
 357                    f"must be equal to the length of the behaviors dictionary "
 358                    f"{self._episode(load_episode).get_behaviors_dict()}"
 359                )
 360        data_param_update = {}
 361        if data_path is not None:
 362            data_param_update = {"data_path": data_path}
 363        if file_paths is not None:
 364            data_param_update = {"data_path": None, "file_paths": file_paths}
 365        parameters_update = self._update(parameters_update, {"data": data_param_update})
 366        if data_path is not None or file_paths is not None:
 367            general_update = {
 368                "annotation_type": "none",
 369                "only_load_annotated": False,
 370            }
 371        else:
 372            general_update = {}
 373        parameters_update = self._update(parameters_update, {"general": general_update})
 374        task, parameters = self._make_task(
 375            episode_name=prediction_name,
 376            load_episode=load_episode,
 377            parameters_update=parameters_update,
 378            parameters_update_second=parameters_update_second,
 379            load_epoch=load_epoch,
 380            purpose="prediction",
 381            task=task,
 382            behaviors=self.get_behavior_dictionary(load_episode),
 383        )
 384        # if data_path is not None or file_paths is not None:
 385        #     print('SETTING')
 386        #     task.set_behaviors(self.get_behavior_dictionary(load_episode))
 387        if mode is None:
 388            if task.exists("test"):
 389                mode = "test"
 390            elif task.exists("val"):
 391                mode = "val"
 392            else:
 393                mode = "train"
 394        return task, parameters, mode
 395
 396    def _make_task_training(
 397        self,
 398        episode_name: str,
 399        load_episode: str = None,
 400        parameters_update: Dict = None,
 401        load_epoch: int = None,
 402        load_search: str = None,
 403        load_parameters: list = None,
 404        round_to_binary: list = None,
 405        load_strict: bool = True,
 406        continuing: bool = False,
 407        task: TaskDispatcher = None,
 408        mask_name: str = None,
 409        throwaway: bool = False,
 410    ) -> Tuple[TaskDispatcher, Dict, str]:
 411        """
 412        Make a `TaskDispatcher` object that will be used to generate a prediction
 413        """
 414
 415        if parameters_update is None:
 416            parameters_update = {}
 417        if continuing:
 418            purpose = "continuing"
 419        else:
 420            purpose = "training"
 421        if mask_name is not None:
 422            mask_name = os.path.join(self._mask_path(), f"{mask_name}.pickle")
 423        parameters_update_second = {"data": {"real_lens": mask_name}}
 424        if throwaway:
 425            parameters_update = self._update(
 426                parameters_update, {"training": {"normalize": False, "device": "cpu"}}
 427            )
 428        return self._make_task(
 429            episode_name,
 430            load_episode,
 431            parameters_update,
 432            parameters_update_second,
 433            load_epoch,
 434            load_search,
 435            load_parameters,
 436            round_to_binary,
 437            purpose,
 438            task,
 439            load_strict=load_strict,
 440        )
 441
 442    def _make_parameters(
 443        self,
 444        episode_name: str,
 445        load_episode: str = None,
 446        parameters_update: Dict = None,
 447        parameters_update_second: Dict = None,
 448        load_epoch: int = None,
 449        load_search: str = None,
 450        load_parameters: list = None,
 451        round_to_binary: list = None,
 452        purpose: str = "train",
 453        load_strict: bool = True,
 454    ):
 455        """
 456        Construct a parameters dictionary
 457        """
 458
 459        if parameters_update is None:
 460            parameters_update = {}
 461        pars_update = deepcopy(parameters_update)
 462        if parameters_update_second is None:
 463            parameters_update_second = {}
 464        if purpose == "prediction" and "model" in pars_update.keys():
 465            raise ValueError("Cannot change model parameters after training!")
 466        if purpose in ["continuing", "prediction"] and load_episode is not None:
 467            read_parameters = self._read_parameters()
 468            parameters = self._episodes().load_parameters(load_episode)
 469            parameters["metrics"] = self._update(
 470                read_parameters["metrics"], parameters["metrics"]
 471            )
 472            parameters["ssl"] = self._update(
 473                read_parameters["ssl"], parameters.get("ssl", {})
 474            )
 475        else:
 476            parameters = self._read_parameters()
 477        if "model" in pars_update:
 478            model_params = pars_update.pop("model")
 479        else:
 480            model_params = None
 481        if "features" in pars_update:
 482            feat_params = pars_update.pop("features")
 483        else:
 484            feat_params = None
 485        if "augmentations" in pars_update:
 486            aug_params = pars_update.pop("augmentations")
 487        else:
 488            aug_params = None
 489        parameters = self._update(parameters, pars_update)
 490        if pars_update.get("general", {}).get("model_name") is not None:
 491            model_name = parameters["general"]["model_name"]
 492            parameters["model"] = self._open_yaml(
 493                os.path.join(self.project_path, "config", "model", f"{model_name}.yaml")
 494            )
 495        if pars_update.get("general", {}).get("feature_extraction") is not None:
 496            feat_name = parameters["general"]["feature_extraction"]
 497            parameters["features"] = self._open_yaml(
 498                os.path.join(
 499                    self.project_path, "config", "features", f"{feat_name}.yaml"
 500                )
 501            )
 502            aug_name = options.extractor_to_transformer[
 503                parameters["general"]["feature_extraction"]
 504            ]
 505            parameters["augmentations"] = self._open_yaml(
 506                os.path.join(
 507                    self.project_path, "config", "augmentations", f"{aug_name}.yaml"
 508                )
 509            )
 510        if model_params is not None:
 511            parameters["model"] = self._update(parameters["model"], model_params)
 512        if feat_params is not None:
 513            parameters["features"] = self._update(parameters["features"], feat_params)
 514        if aug_params is not None:
 515            parameters["augmentations"] = self._update(
 516                parameters["augmentations"], aug_params
 517            )
 518        if load_search is not None:
 519            parameters = self._update_with_search(
 520                parameters, load_search, load_parameters, round_to_binary
 521            )
 522        parameters = self._fill(
 523            parameters,
 524            episode_name,
 525            load_episode,
 526            load_epoch=load_epoch,
 527            load_strict=load_strict,
 528            only_load_model=(purpose != "continuing"),
 529            continuing=(purpose in ["prediction", "continuing"]),
 530            enforce_split_parameters=(purpose == "prediction"),
 531        )
 532        parameters = self._update(parameters, parameters_update_second)
 533        return parameters
 534
 535    def _make_task(
 536        self,
 537        episode_name: str,
 538        load_episode: str = None,
 539        parameters_update: Dict = None,
 540        parameters_update_second: Dict = None,
 541        load_epoch: int = None,
 542        load_search: str = None,
 543        load_parameters: list = None,
 544        round_to_binary: list = None,
 545        purpose: str = "train",
 546        task: TaskDispatcher = None,
 547        load_strict: bool = True,
 548        behaviors: Dict = None,
 549    ) -> Tuple[TaskDispatcher, Union[CommentedMap, dict]]:
 550        """
 551        Make a `TaskDispatcher` object
 552
 553        The task parameters are read from the config files and then updated with the
 554        parameters_update dictionary. The model can be either initialized from scratch or loaded from one of the
 555        previous experiments. All parameters and results are saved in the meta files and can be accessed with the
 556        list_episodes() function. The train/test/validation split is saved and loaded from a file whenever the
 557        same split parameters are used. The pre-computed datasets are also saved and loaded whenever the same
 558        data parameters are used.
 559
 560        Parameters
 561        ----------
 562        episode_name : str
 563            the name of the episode
 564        load_episode : str, optional
 565            the (previously run) episode name to load the model from
 566        parameters_update : dict, optional
 567            the dictionary used to update the parameters from the config
 568        parameters_update_second : dict, optional
 569            the dictionary used to update the parameters after the automatic fill-out
 570        load_epoch : int, optional
 571            the epoch to load (if load_episodes is not None); if not provided, the last epoch is used
 572        load_search : str, optional
 573            the hyperparameter search result to load
 574        load_parameters : list, optional
 575            a list of string names of the parameters to load from load_search (if not provided, all parameters
 576            are loaded)
 577        round_to_binary : list, optional
 578            a list of string names of the loaded parameters that should be rounded to the nearest power of two
 579        purpose : {"train", "continuing", "prediction"}
 580            the purpose of the task object (`"train"` for training from scratch, `"continuing"` for continuing
 581            the training of an interrupted episode, `"prediction"` for generating a prediction)
 582        task : TaskDispatcher, optional
 583            a pre-existing task; if provided, the method will update the task instead of creating a new one
 584            (this might save time, mainly on dataset loading)
 585
 586        Returns
 587        -------
 588        task : TaskDispatcher
 589            the `TaskDispatcher` instance
 590        parameters : dict
 591            the parameters dictionary that describes the task
 592        """
 593
 594        parameters = self._make_parameters(
 595            episode_name,
 596            load_episode,
 597            parameters_update,
 598            parameters_update_second,
 599            load_epoch,
 600            load_search,
 601            load_parameters,
 602            round_to_binary,
 603            purpose,
 604            load_strict=load_strict,
 605        )
 606        if parameters["data"].get("annotation_type", "none") == "none":
 607            parameters = self._update(
 608                parameters, {"data": {"behavior_dictionary": behaviors}}
 609            )
 610        if task is None:
 611            task = TaskDispatcher(parameters)
 612        else:
 613            task.update_task(parameters)
 614        self._save_stores(parameters)
 615        return task, parameters
 616
 617    def run_episode(
 618        self,
 619        episode_name: str,
 620        load_episode: str = None,
 621        parameters_update: Dict = None,
 622        task: TaskDispatcher = None,
 623        load_epoch: int = None,
 624        load_search: str = None,
 625        load_parameters: list = None,
 626        round_to_binary: list = None,
 627        load_strict: bool = True,
 628        n_seeds: int = 1,
 629        force: bool = False,
 630        suppress_name_check: bool = False,
 631        remove_saved_features: bool = False,
 632        mask_name: str = None,
 633        autostop_metric: str = None,
 634        autostop_interval: int = 50,
 635        autostop_threshold: float = 0.001,
 636        loading_bar: bool = False,
 637        trial: Tuple = None,
 638    ) -> TaskDispatcher:
 639        """
 640        Run an episode
 641
 642        The task parameters are read from the config files and then updated with the
 643        parameters_update dictionary. The model can be either initialized from scratch or loaded from one of the
 644        previous experiments. All parameters and results are saved in the meta files and can be accessed with the
 645        list_episodes() function. The train/test/validation split is saved and loaded from a file whenever the
 646        same split parameters are used. The pre-computed datasets are also saved and loaded whenever the same
 647        data parameters are used.
 648
 649        You can use the autostop parameters to finish training when the parameters are not improving. It will be
 650        stopped if the average value of `autostop_metric` over the last `autostop_interval` epochs is smaller than
 651        the average over the previous `autostop_interval` epochs + `autostop_threshold`. For example, if the
 652        current epoch is 120 and `autostop_interval` is 50, the averages over epochs 70-120 and 20-70 will be compared.
 653
 654        Parameters
 655        ----------
 656        episode_name : str
 657            the episode name
 658        load_episode : str, optional
 659            the (previously run) episode name to load the model from; if the episode has multiple runs,
 660            the new episode will have the same number of runs, each starting with one of the pre-trained models
 661        parameters_update : dict, optional
 662            the dictionary used to update the parameters from the config files
 663        task : TaskDispatcher, optional
 664            a pre-existing `TaskDispatcher` object (if provided, the method will update it instead of creating
 665            a new instance)
 666        load_epoch : int, optional
 667            the epoch to load (if load_episodes is not None); if not provided, the last epoch is used
 668        load_search : str, optional
 669            the hyperparameter search result to load
 670        load_parameters : list, optional
 671            a list of string names of the parameters to load from load_search (if not provided, all parameters
 672            are loaded)
 673        round_to_binary : list, optional
 674            a list of string names of the loaded parameters that should be rounded to the nearest power of two
 675        load_strict : bool, default True
 676            if `False`, matching weights will be loaded from `load_episode` and differences in parameter name lists and
 677            weight shapes will be ignored; otherwise mismatches will prompt a `RuntimeError`
 678        n_seeds : int, default 1
 679            the number of runs to perform with different random seeds; if `n_seeds > 1`, the episodes will be named
 680            `episode_name::seed_index`, e.g. `test_episode::0` and `test_episode::1`
 681        force : bool, default False
 682            if `True` and an episode with name `episode_name` already exists, it will be overwritten (use with caution!)
 683        suppress_name_check : bool, default False
 684            if `True`, episode names with a double colon are allowed (please don't use this option unless you understand
 685            why they are usually forbidden)
 686        remove_saved_features : bool, default False
 687            if `True`, the dataset will be deleted after training
 688        mask_name : str, optional
 689            the name of the real_lens to apply
 690        autostop_interval : int, default 50
 691            the number of epochs to average the autostop metric over
 692        autostop_threshold : float, default 0.001
 693            the autostop difference threshold
 694        autostop_metric : str, optional
 695            the autostop metric (can be any one of the tracked metrics of `'loss'`)
 696        """
 697
 698        if type(n_seeds) is not int or n_seeds < 1:
 699            raise ValueError(
 700                f"The n_seeds parameter has to be an integer larger than 0; got {n_seeds}"
 701            )
 702        if n_seeds > 1 and mask_name is not None:
 703            raise ValueError("Cannot apply a real_lens with n_seeds > 1")
 704        self._check_episode_validity(
 705            episode_name, allow_doublecolon=suppress_name_check, force=force
 706        )
 707        load_runs = self._episodes().get_runs(load_episode)
 708        if len(load_runs) > 1:
 709            task = self.run_episodes(
 710                episode_names=[
 711                    f'{episode_name}::{run.split("::")[-1]}' for run in load_runs
 712                ],
 713                load_episodes=load_runs,
 714                parameters_updates=[parameters_update for _ in load_runs],
 715                load_epochs=[load_epoch for _ in load_runs],
 716                load_searches=[load_search for _ in load_runs],
 717                load_parameters=[load_parameters for _ in load_runs],
 718                round_to_binary=[round_to_binary for _ in load_runs],
 719                load_strict=[load_strict for _ in load_runs],
 720                suppress_name_check=True,
 721                force=force,
 722                remove_saved_features=False,
 723            )
 724            if remove_saved_features:
 725                self._remove_stores(
 726                    {
 727                        "general": task.general_parameters,
 728                        "data": task.data_parameters,
 729                        "features": task.feature_parameters,
 730                    }
 731                )
 732            if n_seeds > 1:
 733                warnings.warn(
 734                    f"The n_seeds parameter is disregarded since load_episode={load_episode} has multiple runs"
 735                )
 736        elif n_seeds > 1:
 737            self.run_episodes(
 738                episode_names=[f"{episode_name}::{i}" for i in range(n_seeds)],
 739                load_episodes=[load_episode for _ in range(n_seeds)],
 740                parameters_updates=[parameters_update for _ in range(n_seeds)],
 741                load_epochs=[load_epoch for _ in range(n_seeds)],
 742                load_searches=[load_search for _ in range(n_seeds)],
 743                load_parameters=[load_parameters for _ in range(n_seeds)],
 744                round_to_binary=[round_to_binary for _ in range(n_seeds)],
 745                load_strict=[load_strict for _ in range(n_seeds)],
 746                suppress_name_check=True,
 747                force=force,
 748                remove_saved_features=remove_saved_features,
 749            )
 750        else:
 751            print(f"TRAINING {episode_name}")
 752            try:
 753                task, parameters = self._make_task_training(
 754                    episode_name,
 755                    load_episode,
 756                    parameters_update,
 757                    load_epoch,
 758                    load_search,
 759                    load_parameters,
 760                    round_to_binary,
 761                    continuing=False,
 762                    task=task,
 763                    mask_name=mask_name,
 764                    load_strict=load_strict,
 765                )
 766                self._save_episode(
 767                    episode_name,
 768                    parameters,
 769                    task.behaviors_dict(),
 770                    norm_stats=task.get_normalization_stats(),
 771                )
 772                time_start = time.time()
 773                if trial is not None:
 774                    trial, metric = trial
 775                else:
 776                    trial, metric = None, None
 777                logs = task.train(
 778                    autostop_metric=autostop_metric,
 779                    autostop_interval=autostop_interval,
 780                    autostop_threshold=autostop_threshold,
 781                    loading_bar=loading_bar,
 782                    trial=trial,
 783                    optimized_metric=metric,
 784                )
 785                time_end = time.time()
 786                time_total = time_end - time_start
 787                hours = int(time_total // 3600)
 788                time_total -= hours * 3600
 789                minutes = int(time_total // 60)
 790                time_total -= minutes * 60
 791                seconds = int(time_total)
 792                training_time = f"{hours}:{minutes:02}:{seconds:02}"
 793                self._update_episode_results(episode_name, logs, training_time)
 794                if remove_saved_features:
 795                    self._remove_stores(parameters)
 796                print("\n")
 797                return task
 798
 799            except Exception as e:
 800                if isinstance(e, optuna.exceptions.TrialPruned):
 801                    raise e
 802                else:
 803                    # if str(e) != f"The {episode_name} episode name is already in use!":
 804                    #     self.remove_episode(episode_name)
 805                    raise RuntimeError(f"Episode {episode_name} could not run")
 806
 807    def run_episodes(
 808        self,
 809        episode_names: List,
 810        load_episodes: List = None,
 811        parameters_updates: List = None,
 812        load_epochs: List = None,
 813        load_searches: List = None,
 814        load_parameters: List = None,
 815        round_to_binary: List = None,
 816        load_strict: List = None,
 817        force: bool = False,
 818        suppress_name_check: bool = False,
 819        remove_saved_features: bool = False,
 820    ) -> TaskDispatcher:
 821        """
 822        Run multiple episodes in sequence (and re-use previously loaded information)
 823
 824        For each episode, the task parameters are read from the config files and then updated with the
 825        parameter_update dictionary. The model can be either initialized from scratch or loaded from one of the
 826        previous experiments. All parameters and results are saved in the meta files and can be accessed with the
 827        list_episodes() function. The train/test/validation split is saved and loaded from a file whenever the
 828        same split parameters are used. The pre-computed datasets are also saved and loaded whenever the same
 829        data parameters are used.
 830
 831        Parameters
 832        ----------
 833        episode_names : list
 834            a list of strings of episode names
 835        load_episodes : list, optional
 836            a list of strings of (previously run) episode names to load the model from; if the episode has multiple runs,
 837            the new episode will have the same number of runs, each starting with one of the pre-trained models
 838        parameters_updates : list, optional
 839            a list of dictionaries used to update the parameters from the config
 840        load_epochs : list, optional
 841            a list of integers used to specify the epoch to load (if load_episodes is not None)
 842        load_searches : list, optional
 843            a list of strings of hyperparameter search results to load
 844        load_parameters : list, optional
 845            a list of lists of string names of the parameters to load from the searches
 846        round_to_binary : list, optional
 847            a list of string names of the loaded parameters that should be rounded to the nearest power of two
 848        load_strict : list, optional
 849            a list of boolean values specifying weight loading policy: if `False`, matching weights will be loaded from
 850            the corresponding episode and differences in parameter name lists and
 851            weight shapes will be ignored; otherwise mismatches will prompt a `RuntimeError` (by default `True` for
 852            every episode)
 853        force : bool, default False
 854            if `True` and an episode name is already taken, it will be overwritten (use with caution!)
 855        suppress_name_check : bool, default False
 856            if `True`, episode names with a double colon are allowed (please don't use this option unless you understand
 857            why they are usually forbidden)
 858        remove_saved_features : bool, default False
 859            if `True`, the dataset will be deleted after training
 860        """
 861
 862        task = None
 863        if load_searches is None:
 864            load_searches = [None for _ in episode_names]
 865        if load_episodes is None:
 866            load_episodes = [None for _ in episode_names]
 867        if parameters_updates is None:
 868            parameters_updates = [None for _ in episode_names]
 869        if load_parameters is None:
 870            load_parameters = [None for _ in episode_names]
 871        if load_epochs is None:
 872            load_epochs = [None for _ in episode_names]
 873        if load_strict is None:
 874            load_strict = [True for _ in episode_names]
 875        for (
 876            parameters_update,
 877            episode_name,
 878            load_episode,
 879            load_epoch,
 880            load_search,
 881            load_parameters_list,
 882            load_strict_value,
 883        ) in zip(
 884            parameters_updates,
 885            episode_names,
 886            load_episodes,
 887            load_epochs,
 888            load_searches,
 889            load_parameters,
 890            load_strict,
 891        ):
 892            task = self.run_episode(
 893                episode_name,
 894                load_episode,
 895                parameters_update,
 896                task,
 897                load_epoch,
 898                load_search,
 899                load_parameters_list,
 900                round_to_binary,
 901                load_strict_value,
 902                suppress_name_check=suppress_name_check,
 903                force=force,
 904                remove_saved_features=remove_saved_features,
 905            )
 906        return task
 907
 908    def continue_episode(
 909        self,
 910        episode_name: str,
 911        num_epochs: int = None,
 912        task: TaskDispatcher = None,
 913        n_seeds: int = 1,
 914        remove_saved_features: bool = False,
 915        device: str = "cuda",
 916        num_cpus: int = None,
 917    ) -> TaskDispatcher:
 918        """
 919        Load an older episode and continue running from the latest checkpoint
 920
 921        All parameters as well as the model and optimizer state dictionaries are loaded from the episode.
 922
 923        Parameters
 924        ----------
 925        episode_name : str
 926            the name of the episode to continue
 927        num_epochs : int, optional
 928            the new number of epochs
 929        task : TaskDispatcher, optional
 930            a pre-existing task; if provided, the method will update the task instead of creating a new one
 931            (this might save time, mainly on dataset loading)
 932        result_average_interval : int, default 5
 933            the metric are averaged over the last result_average_interval to be stored in the episodes meta file
 934            and displayed by list_episodes() function (the full log is still always available)
 935        n_seeds : int, default 1
 936            the number of runs to perform; if `n_seeds > 1`, the episodes will be named `episode_name::run_index`, e.g.
 937            `test_episode::0` and `test_episode::1`
 938        remove_saved_features : bool, default False
 939            if `True`, pre-computed features will be deleted after the run
 940        device : str, default "cuda"
 941            the torch device to use
 942        """
 943
 944        runs = self._episodes().get_runs(episode_name)
 945        for run in runs:
 946            print(f"TRAINING {run}")
 947            if num_epochs is None and not self._episode(run).unfinished():
 948                continue
 949            parameters_update = {
 950                "training": {
 951                    "num_epochs": num_epochs,
 952                    "device": device,
 953                },
 954                "general": {"num_cpus": num_cpus},
 955            }
 956            task, parameters = self._make_task_training(
 957                run,
 958                load_episode=run,
 959                parameters_update=parameters_update,
 960                continuing=True,
 961                task=task,
 962            )
 963            time_start = time.time()
 964            logs = task.train()
 965            time_end = time.time()
 966            old_time = self._training_time(run)
 967            if not np.isnan(old_time):
 968                time_end += old_time
 969                time_total = time_end - time_start
 970                hours = int(time_total // 3600)
 971                time_total -= hours * 3600
 972                minutes = int(time_total // 60)
 973                time_total -= minutes * 60
 974                seconds = int(time_total)
 975                training_time = f"{hours}:{minutes:02}:{seconds:02}"
 976            else:
 977                training_time = np.nan
 978            self._save_episode(
 979                run,
 980                parameters,
 981                task.behaviors_dict(),
 982                suppress_validation=True,
 983                training_time=training_time,
 984                norm_stats=task.get_normalization_stats(),
 985            )
 986            self._update_episode_results(run, logs)
 987            print("\n")
 988        if len(runs) < n_seeds:
 989            for i in range(len(runs), n_seeds):
 990                self.run_episode(
 991                    f"{episode_name}::{i}",
 992                    parameters_update=self._episodes().load_parameters(runs[0]),
 993                    task=task,
 994                    suppress_name_check=True,
 995                )
 996        if remove_saved_features:
 997            self._remove_stores(parameters)
 998        return task
 999
1000    def run_default_hyperparameter_search(
1001        self,
1002        search_name: str,
1003        model_name: str = None,
1004        metric: str = "f1",
1005        best_n: int = 3,
1006        direction: str = "maximize",
1007        load_episode: str = None,
1008        load_epoch: int = None,
1009        load_strict: bool = True,
1010        prune: bool = True,
1011        force: bool = False,
1012        remove_saved_features: bool = False,
1013        overlap: float = 0,
1014        num_epochs: int = 50,
1015        test_frac: float = 0,
1016        n_trials=150,
1017        device: str = None,
1018    ):
1019        """
1020        Run an optuna hyperparameter search with default parameters for a model
1021
1022        For the vast majority of cases, optimizing the default parameters should be enough.
1023        Check out `dlc2action.options.model_hyperparameters` for the lists of parameters.
1024        There are also options to set overlap, test fraction and number of epochs parameters for the search without
1025        modifying the project config files. However, if you want something more complex, look into
1026        `Project.run_hyperparameter_search`.
1027
1028        The task parameters are read from the config files and updated with the parameters_update dictionary.
1029        The model can be either initialized from scratch or loaded from a previously run episode.
1030        For each trial, the objective metric is averaged over a few best epochs.
1031
1032        Parameters
1033        ----------
1034        search_name : str
1035            the name of the search to store it in the meta files and load in run_episode
1036        model_name : str, optional
1037            the name of the model (by default loaded from the project settings, see `project.help('models')` for options)
1038        metric : str, default f1
1039            the metric to maximize/minimize (see direction); if the metric has an `"average"` parameter and it is set to
1040            `"none"` in the config files, it will be reset to `"macro"` for the search; see `project.help('metrics')` for options
1041        n_trials : int, default 20
1042            the number of optimization trials to run
1043        best_n : int, default 1
1044            the number of epochs to average the metric; if 0, the last value is taken
1045        parameters_update : dict, optional
1046            the parameters update dictionary
1047        direction : {'maximize', 'minimize'}
1048            optimization direction
1049        load_episode : str, optional
1050            the name of the episode to load the model from
1051        load_epoch : int, optional
1052            the epoch to load the model from (if not provided, the last checkpoint is used)
1053        prune : bool, default False
1054            if `True`, experiments where the optimized metric is improving too slowly will be terminated
1055            (with optuna HyperBand pruner)
1056        force : bool, default False
1057            if `True`, existing searches with the same name will be overwritten
1058        remove_saved_features : bool, default False
1059            if `True`, pre-computed features will be deleted after each run (if the data parameters change)
1060        device : str, optional
1061            cuda:{i} or cpu, if not given it is read from the default parameters
1062
1063        Returns
1064        -------
1065        dict
1066            a dictionary of best parameters
1067        """
1068
1069        if model_name is None:
1070            model_name = self._read_parameters()["general"]["model_name"]
1071        if model_name not in options.model_hyperparameters:
1072            raise ValueError(
1073                f"There is no default search space for {model_name}! Please choose from {options.model_hyperparameters.keys()} or try project.run_hyperparameter_search()"
1074            )
1075        pars = {
1076            "general": {
1077                "overlap": overlap,
1078                "model_name": model_name,
1079                "metric_functions": {metric},
1080            },
1081            "training": {"num_epochs": num_epochs},
1082        }
1083        if test_frac is not None:
1084            pars["training"]["test_frac"] = test_frac
1085        if not metric.split("_")[-1].isnumeric():
1086            project_pars = self._read_parameters()
1087            if project_pars["metrics"][metric].get("average") == "none":
1088                pars["metrics"] = {metric: {"average": "macro"}}
1089        if device is not None:
1090            pars["training"]["device"] = device
1091        return self.run_hyperparameter_search(
1092            search_name=search_name,
1093            search_space=options.model_hyperparameters[model_name],
1094            metric=metric,
1095            n_trials=n_trials,
1096            best_n=best_n,
1097            parameters_update=pars,
1098            direction=direction,
1099            load_episode=load_episode,
1100            load_epoch=load_epoch,
1101            load_strict=load_strict,
1102            prune=prune,
1103            force=force,
1104            remove_saved_features=remove_saved_features,
1105        )
1106
1107    def run_hyperparameter_search(
1108        self,
1109        search_name: str,
1110        search_space: Dict,
1111        metric: str = "f1",
1112        n_trials: int = 20,
1113        best_n: int = 1,
1114        parameters_update: Dict = None,
1115        direction: str = "maximize",
1116        load_episode: str = None,
1117        load_epoch: int = None,
1118        load_strict: bool = True,
1119        prune: bool = False,
1120        force: bool = False,
1121        remove_saved_features: bool = False,
1122    ) -> Dict:
1123        """
1124        Run an optuna hyperparameter search
1125
1126        For a simpler function that fits most use cases, check out `Project.run_default_hyperparameter_search()`.
1127
1128        To use a default search space with this method, import `dlc2action.options.model_hyperparameters`. It is
1129        a dictionary where keys are model names and values are default search spaces.
1130
1131        The task parameters are read from the config files and updated with the parameters_update dictionary.
1132        The model can be either initialized from scratch or loaded from a previously run episode.
1133        For each trial, the objective metric is averaged over a few best epochs.
1134
1135        Parameters
1136        ----------
1137        search_name : str
1138            the name of the search to store it in the meta files and load in run_episode
1139        search_space : dict
1140            a dictionary representing the search space; of this general structure:
1141            {'group/param_name': ('float/int/float_log/int_log', start, end),
1142            'group/param_name': ('categorical', [choices])}, e.g.
1143            {'data/overlap': ('int', 5, 100), 'training/lr': ('float_log', 1e-4, 1e-2),
1144            'data/feature_extraction': ('categorical', ['kinematic', 'bones'])};
1145        metric : str, default f1
1146            the metric to maximize/minimize (see direction)
1147        n_trials : int, default 20
1148            the number of optimization trials to run
1149        best_n : int, default 1
1150            the number of epochs to average the metric; if 0, the last value is taken
1151        parameters_update : dict, optional
1152            the parameters update dictionary
1153        direction : {'maximize', 'minimize'}
1154            optimization direction
1155        load_episode : str, optional
1156            the name of the episode to load the model from
1157        load_epoch : int, optional
1158            the epoch to load the model from (if not provided, the last checkpoint is used)
1159        prune : bool, default False
1160            if `True`, experiments where the optimized metric is improving too slowly will be terminated
1161            (with optuna HyperBand pruner)
1162        force : bool, default False
1163            if `True`, existing searches with the same name will be overwritten
1164        remove_saved_features : bool, default False
1165            if `True`, pre-computed features will be deleted after each run (if the data parameters change)
1166
1167        Returns
1168        -------
1169        dict
1170            a dictionary of best parameters
1171        """
1172
1173        self._check_search_validity(search_name, force=force)
1174        print(f"SEARCH {search_name}")
1175        self.remove_episode(f"_{search_name}")
1176        if parameters_update is None:
1177            parameters_update = {}
1178        parameters_update = self._update(
1179            parameters_update, {"general": {"metric_functions": {metric}}}
1180        )
1181        parameters = self._make_parameters(
1182            f"_{search_name}",
1183            load_episode,
1184            parameters_update,
1185            parameters_update_second={"training": {"model_save_path": None}},
1186            load_epoch=load_epoch,
1187            load_strict=load_strict,
1188        )
1189        task = None
1190
1191        if prune:
1192            pruner = optuna.pruners.HyperbandPruner()
1193        else:
1194            pruner = optuna.pruners.NopPruner()
1195        study = optuna.create_study(direction=direction, pruner=pruner)
1196        runner = _Runner(
1197            search_space=search_space,
1198            load_episode=load_episode,
1199            load_epoch=load_epoch,
1200            metric=metric,
1201            average=best_n,
1202            task=task,
1203            remove_saved_features=remove_saved_features,
1204            project=self,
1205            search_name=search_name,
1206        )
1207        study.optimize(lambda trial: runner.run(trial, parameters), n_trials=n_trials)
1208        search_path = self._search_path(search_name)
1209        os.mkdir(search_path)
1210        fig = optuna.visualization.plot_contour(study)
1211        plotly.offline.plot(
1212            fig, filename=os.path.join(search_path, f"{search_name}_contour.html")
1213        )
1214        fig = optuna.visualization.plot_param_importances(study)
1215        plotly.offline.plot(
1216            fig, filename=os.path.join(search_path, f"{search_name}_importances.html")
1217        )
1218        best_params = study.best_params
1219        best_value = study.best_value
1220        self._save_search(
1221            search_name,
1222            parameters,
1223            n_trials,
1224            best_params,
1225            best_value,
1226            metric,
1227            search_space,
1228        )
1229        self.remove_episode(f"_{search_name}")
1230        runner.clean()
1231        print(f"best parameters: {best_params}")
1232        print("\n")
1233        return best_params
1234
1235    def run_prediction(
1236        self,
1237        prediction_name: str,
1238        episode_names: List,
1239        load_epochs: List = None,
1240        parameters_update: Dict = None,
1241        augment_n: int = 10,
1242        data_path: str = None,
1243        mode: str = "all",
1244        file_paths: Set = None,
1245        remove_saved_features: bool = False,
1246        submission: bool = False,
1247        frame_number_map_file: str = None,
1248        force: bool = False,
1249        embedding: bool = False,
1250    ) -> None:
1251        """
1252        Load models from previously run episodes to generate a prediction
1253
1254        The probabilities predicted by the models are averaged.
1255        Unless `submission` is `True`, the prediction results are saved as a pickled dictionary in the project_name/results/predictions folder
1256        under the {episode_name}_{load_epoch}.pickle name. The file is a nested dictionary where the first-level
1257        keys are the video ids, the second-level keys are the clip ids (like individual names) and the values
1258        are the prediction arrays.
1259
1260        Parameters
1261        ----------
1262        prediction_name : str
1263            the name of the prediction
1264        episode_names : list
1265            a list of string episode names to load the models from
1266        load_epochs : list, optional
1267            a list of integer epoch indices to load the model from; if None, the last ones are used
1268        parameters_update : dict, optional
1269            a dictionary of parameter updates
1270        augment_n : int, default 10
1271            the number of augmentations to average over
1272        data_path : str, optional
1273            the data path to run the prediction for
1274        mode : {'all', 'test', 'val', 'train'}
1275            the subset of the data to make the prediction for (forced to 'all' if data_path is not None)
1276        file_paths : set, optional
1277            a set of string file paths (data with all prefixes + feature files, in any order) to run the prediction
1278            for
1279        remove_saved_features : bool, default False
1280            if `True`, pre-computed features will be deleted
1281        submission : bool, default False
1282            if `True`, a MABe-22 style submission file is generated
1283        frame_number_map_file : str, optional
1284            path to the frame number map file
1285        force : bool, default False
1286            if `True`, existing prediction with this name will be overwritten
1287        """
1288
1289        self._check_prediction_validity(prediction_name, force=force)
1290        print(f"PREDICTION {prediction_name}")
1291        if submission:
1292            task = ...
1293            # TODO: add submission option to _make_prediction
1294            predicted = task.generate_submission(
1295                frame_number_map_file=frame_number_map_file,
1296                dataset=mode,
1297                augment_n=augment_n,
1298            )
1299            folder = os.path.join(
1300                self.project_path,
1301                "results",
1302                "predictions",
1303                f"{prediction_name}",
1304            )
1305            filename = os.path.join(folder, f"{prediction_name}.npy")
1306            np.save(filename, predicted, allow_pickle=True)
1307        else:
1308            try:
1309                (
1310                    task,
1311                    parameters,
1312                    mode,
1313                    prediction,
1314                    inference_time,
1315                ) = self._make_prediction(
1316                    prediction_name,
1317                    episode_names,
1318                    load_epochs,
1319                    parameters_update,
1320                    data_path,
1321                    file_paths,
1322                    mode,
1323                    augment_n,
1324                    evaluate=False,
1325                    embedding=embedding,
1326                )
1327                predicted = task.dataset(mode).generate_full_length_prediction(
1328                    prediction
1329                )
1330            except ValueError:
1331                (
1332                    task,
1333                    parameters,
1334                    mode,
1335                    predicted,
1336                    inference_time,
1337                ) = self._aggregate_predictions(
1338                    prediction_name,
1339                    episode_names,
1340                    load_epochs,
1341                    parameters_update,
1342                    data_path,
1343                    file_paths,
1344                    mode,
1345                    augment_n,
1346                    evaluate=False,
1347                    embedding=embedding,
1348                )
1349            folder = self.prediction_path(prediction_name)
1350            os.mkdir(folder)
1351            for video_id, prediction in predicted.items():
1352                with open(
1353                    os.path.join(
1354                        folder, video_id + f"_{prediction_name}_prediction.pickle"
1355                    ),
1356                    "wb",
1357                ) as f:
1358                    prediction["min_frames"], prediction["max_frames"] = task.dataset(
1359                        mode
1360                    ).get_min_max_frames(video_id)
1361                    behavior_indices = sorted(
1362                        [key for key in task.behaviors_dict() if key != -100]
1363                    )
1364                    prediction["behaviors"] = [
1365                        task.behaviors_dict()[key] for key in behavior_indices
1366                    ]
1367                    pickle.dump(prediction, f)
1368        if remove_saved_features:
1369            self._remove_stores(parameters)
1370        self._save_prediction(
1371            prediction_name,
1372            parameters,
1373            task.behaviors_dict(),
1374            embedding,
1375            inference_time,
1376        )
1377        print("\n")
1378
1379    def evaluate_prediction(
1380        self,
1381        prediction_name: str,
1382        parameters_update: Dict = None,
1383        data_path: str = None,
1384        file_paths: Set = None,
1385        mode: str = None,
1386        remove_saved_features: bool = False,
1387    ) -> Tuple[float, dict]:
1388
1389        with open(
1390            os.path.join(
1391                self.project_path, "results", "predictions", f"{prediction_name}.pickle"
1392            ),
1393            "rb",
1394        ) as f:
1395            prediction = pickle.load(f)
1396        if parameters_update is None:
1397            parameters_update = {}
1398        parameters_update = self._update(
1399            self._predictions().load_parameters(prediction_name), parameters_update
1400        )
1401        parameters_update.pop("model")
1402        task, parameters, mode = self._make_task_prediction(
1403            "_",
1404            load_episode=None,
1405            parameters_update=parameters_update,
1406            data_path=data_path,
1407            file_paths=file_paths,
1408            mode=mode,
1409        )
1410        results = task.evaluate_prediction(prediction, data=mode)
1411        if remove_saved_features:
1412            self._remove_stores(parameters)
1413        print("\n")
1414        return results
1415
1416    def evaluate(
1417        self,
1418        episode_names: List,
1419        load_epochs: List = None,
1420        augment_n: int = 0,
1421        data_path: str = None,
1422        file_paths: Set = None,
1423        mode: str = None,
1424        parameters_update: Dict = None,
1425        multiple_episode_policy: str = "average",
1426        remove_saved_features: bool = False,
1427        skip_updating_meta: bool = True,
1428    ) -> Dict:
1429        """
1430        Load one or several models from previously run episodes to make an evaluation
1431
1432        By default it will run on the test (or validation, if there is no test) subset of the project dataset.
1433
1434        Parameters
1435        ----------
1436        episode_names : list
1437            a list of string episode names to load the models from
1438        load_epochs : list, optional
1439            a list of integer epoch indices to load the model from; if None, the last ones are used
1440        augment_n : int, default 0
1441            the number of augmentations to average over
1442        data_path : str, optional
1443            the data path to run the prediction for
1444        file_paths : set, optional
1445            a set of files to run the prediction for
1446        mode : {'test', 'val', 'train', 'all'}
1447            the subset of the data to make the prediction for (forced to 'all' if data_path is not None;
1448            by default 'test' if test subset is not empty and 'val' otherwise)
1449        parameters_update : dict, optional
1450            a dictionary with parameter updates (cannot change model parameters)
1451        remove_saved_features : bool, default False
1452            if `True`, the dataset will be deleted
1453
1454        Returns
1455        -------
1456        metric : dict
1457            a dictionary of average values of metric functions
1458        """
1459
1460        names = []
1461        for episode_name in episode_names:
1462            names += self._episodes().get_runs(episode_name)
1463        if len(set(episode_names)) == 1:
1464            print(f"EVALUATION {episode_names[0]}")
1465        else:
1466            print(f"EVALUATION {episode_names}")
1467        if len(names) > 1:
1468            evaluate = True
1469        else:
1470            evaluate = False
1471        if multiple_episode_policy == "average":
1472            try:
1473                (
1474                    task,
1475                    parameters,
1476                    mode,
1477                    prediction,
1478                    inference_time,
1479                ) = self._make_prediction(
1480                    "_",
1481                    episode_names,
1482                    load_epochs,
1483                    parameters_update,
1484                    mode=mode,
1485                    data_path=data_path,
1486                    file_paths=file_paths,
1487                    augment_n=augment_n,
1488                    evaluate=evaluate,
1489                )
1490            except:
1491                (
1492                    task,
1493                    parameters,
1494                    mode,
1495                    prediction,
1496                    inference_time,
1497                ) = self._aggregate_predictions(
1498                    "_",
1499                    episode_names,
1500                    load_epochs,
1501                    parameters_update,
1502                    mode=mode,
1503                    data_path=data_path,
1504                    file_paths=file_paths,
1505                    augment_n=augment_n,
1506                    evaluate=evaluate,
1507                )
1508            print("AGGREGATED:")
1509            _, results = task.evaluate_prediction(prediction, data=mode)
1510            if len(names) == 1 and mode == "val" and not skip_updating_meta:
1511                self._update_episode_metrics(names[0], results)
1512        elif multiple_episode_policy == "statistics":
1513            values = defaultdict(lambda: [])
1514            task = None
1515            for name in names:
1516                (
1517                    task,
1518                    parameters,
1519                    mode,
1520                    prediction,
1521                    inference_time,
1522                ) = self._make_prediction(
1523                    "_",
1524                    [name],
1525                    load_epochs,
1526                    parameters_update,
1527                    mode=mode,
1528                    data_path=data_path,
1529                    file_paths=file_paths,
1530                    augment_n=augment_n,
1531                    evaluate=evaluate,
1532                    task=task,
1533                )
1534                _, metrics = task.evaluate_prediction(prediction, data=mode)
1535                for name, value in metrics.items():
1536                    values[name].append(value)
1537                if mode == "val" and not skip_updating_meta:
1538                    self._update_episode_metrics(name, metrics)
1539            results = defaultdict(lambda: {})
1540            mean_string = ""
1541            std_string = ""
1542            for key, value_list in values.items():
1543                results[key]["mean"] = np.mean(value_list)
1544                results[key]["std"] = np.std(value_list)
1545                mean_string += f"{key} {np.mean(value_list):.3f}, "
1546                std_string += f"{key} {np.std(value_list):.3f}, "
1547            print("MEAN:")
1548            print(mean_string)
1549            print("STD:")
1550            print(std_string)
1551        else:
1552            raise ValueError(
1553                f"The {multiple_episode_policy} multiple episode policy is not recognized; please choose "
1554                f"from ['average', 'statistics']"
1555            )
1556        if len(names) > 0 and remove_saved_features:
1557            self._remove_stores(parameters)
1558        print(f"Inference time: {inference_time}")
1559        print("\n")
1560        return results
1561
1562    def _generate_similarity_score(
1563        self,
1564        prediction_name: str,
1565        target_video_id: str,
1566        target_clip: str,
1567        target_start: int,
1568        target_end: int,
1569    ) -> Dict:
1570        with open(
1571            os.path.join(
1572                self.project_path,
1573                "results",
1574                "predictions",
1575                f"{prediction_name}.pickle",
1576            ),
1577            "rb",
1578        ) as f:
1579            prediction = pickle.load(f)
1580        target = prediction[target_video_id][target_clip][:, target_start:target_end]
1581        score_dict = defaultdict(lambda: {})
1582        for video_id in prediction:
1583            for clip_id in prediction[video_id]:
1584                score_dict[video_id][clip_id] = torch.cdist(
1585                    target.T, prediction[video_id][score_dict].T
1586                ).min(0)
1587        return score_dict
1588
1589    def _suggest_intervals_from_dict(self, score_dict, min_length, n_intervals) -> Dict:
1590        interval_address = {}
1591        interval_value = {}
1592        s = 0
1593        n = 0
1594        for video_id, video_dict in score_dict.items():
1595            for clip_id, value in video_dict.items():
1596                s += value.mean()
1597                n += 1
1598        mean_value = s / n
1599        alpha = 1.75
1600        for it in range(10):
1601            id = 0
1602            interval_address = {}
1603            interval_value = {}
1604            for video_id, video_dict in score_dict.items():
1605                for clip_id, value in video_dict.items():
1606                    res_indices_start, res_indices_end = apply_threshold(
1607                        value,
1608                        threshold=(2 - alpha * (0.9**it)) * mean_value,
1609                        low=True,
1610                        error_mask=None,
1611                        min_frames=min_length,
1612                        smooth_interval=0,
1613                    )
1614                    for start, end in zip(res_indices_start, res_indices_end):
1615                        interval_address[id] = [video_id, clip_id, start, end]
1616                        interval_value[id] = score_dict[video_id][clip_id][
1617                            start:end
1618                        ].mean()
1619                        id += 1
1620            if len(interval_address) >= n_intervals:
1621                break
1622        if len(interval_address) < n_intervals:
1623            warnings.warn(
1624                f"Could not get {n_intervals} intervals from the data, saving the result with {len(interval_address)} intervals"
1625            )
1626        sorted_intervals = sorted(
1627            interval_value.items(), key=lambda x: x[1], reverse=True
1628        )
1629        output_intervals = [
1630            interval_address[x[0]]
1631            for x in sorted_intervals[: min(len(sorted_intervals), n_intervals)]
1632        ]
1633        output = defaultdict(lambda: [])
1634        for video_id, clip_id, start, end in output_intervals:
1635            output[video_id].append([start, end, clip_id])
1636        return output
1637
1638    def list_episodes(
1639        self,
1640        episode_names: List = None,
1641        value_filter: str = "",
1642        display_parameters: List = None,
1643        print_results: bool = True,
1644    ) -> pd.DataFrame:
1645        """
1646        Get a filtered pandas dataframe with episode metadata
1647
1648        Parameters
1649        ----------
1650        episode_names : list
1651            a list of strings of episode names
1652        value_filter : str
1653            a string of filters to apply; of this general structure:
1654            'group_name1/par_name1::(</>/<=/>=/=)value1,group_name2/par_name2::(</>/<=/>=/=)value2', e.g.
1655            'data/overlap::=50,results/recall::>0.5,data/feature_extraction::=kinematic,meta/training_time::>=00:00:10'
1656        display_parameters : list
1657            list of parameters to display (e.g. ['data/overlap', 'results/recall'])
1658        print_results : bool, default True
1659            if True, the result will be printed to standard output
1660
1661        Returns
1662        -------
1663        pd.DataFrame
1664            the filtered dataframe
1665        """
1666
1667        episodes = self._episodes().list_episodes(
1668            episode_names, value_filter, display_parameters
1669        )
1670        if print_results:
1671            print("TRAINING EPISODES")
1672            print(episodes)
1673            print("\n")
1674        return episodes
1675
1676    def list_predictions(
1677        self,
1678        episode_names: List = None,
1679        value_filter: str = "",
1680        display_parameters: List = None,
1681        print_results: bool = True,
1682    ) -> pd.DataFrame:
1683        """
1684        Get a filtered pandas dataframe with prediction metadata
1685
1686        Parameters
1687        ----------
1688        episode_names : list
1689            a list of strings of episode names
1690        value_filter : str
1691            a string of filters to apply; of this general structure:
1692            'group_name1/par_name1:(<>=)value1,group_name2/par_name2:(<>=)value2', e.g.
1693            'data/overlap:=50,results/recall:>0.5,data/feature_extraction:=kinematic'
1694        display_parameters : list
1695            list of parameters to display (e.g. ['data/overlap', 'results/recall'])
1696        print_results : bool, default True
1697            if True, the result will be printed to standard output
1698
1699        Returns
1700        -------
1701        pd.DataFrame
1702            the filtered dataframe
1703        """
1704
1705        predictions = self._predictions().list_episodes(
1706            episode_names, value_filter, display_parameters
1707        )
1708        if print_results:
1709            print("PREDICTIONS")
1710            print(predictions)
1711            print("\n")
1712        return predictions
1713
1714    def list_searches(
1715        self,
1716        search_names: List = None,
1717        value_filter: str = "",
1718        display_parameters: List = None,
1719        print_results: bool = True,
1720    ) -> pd.DataFrame:
1721        """
1722        Get a filtered pandas dataframe with hyperparameter search metadata
1723
1724        Parameters
1725        ----------
1726        search_names : list
1727            a list of strings of search names
1728        value_filter : str
1729            a string of filters to apply; of this general structure:
1730            'group_name1/par_name1:(<>=)value1,group_name2/par_name2:(<>=)value2', e.g.
1731            'data/overlap:=50,results/recall:>0.5,data/feature_extraction:=kinematic'
1732        display_parameters : list
1733            list of parameters to display (e.g. ['data/overlap', 'results/recall'])
1734        print_results : bool, default True
1735            if True, the result will be printed to standard output
1736
1737        Returns
1738        -------
1739        pd.DataFrame
1740            the filtered dataframe
1741        """
1742
1743        searches = self._searches().list_episodes(
1744            search_names, value_filter, display_parameters
1745        )
1746        if print_results:
1747            print("SEARCHES")
1748            print(searches)
1749            print("\n")
1750        return searches
1751
1752    def get_best_parameters(
1753        self,
1754        search_name: str,
1755        round_to_binary: List = None,
1756    ):
1757        params, model = self._searches().get_best_params(
1758            search_name, round_to_binary=round_to_binary
1759        )
1760        params = self._update(params, {"general": {"model_name": model}})
1761        return params
1762
1763    def list_best_parameters(
1764        self, search_name: str, print_results: bool = True
1765    ) -> Dict:
1766        """
1767        Get the raw dictionary of best parameters found by a search
1768
1769        Parameters
1770        ----------
1771        search_name : str
1772            the name of the search
1773        print_results : bool, default True
1774            if True, the result will be printed to standard output
1775
1776        Returns
1777        -------
1778        best_params : dict
1779            a dictionary of the best parameters where the keys are in '{group}/{name}' format
1780        """
1781
1782        params = self._searches().get_best_params_raw(search_name)
1783        if print_results:
1784            print(f"SEARCH RESULTS {search_name}")
1785            for k, v in params.items():
1786                print(f"{k}: {v}")
1787            print("\n")
1788        return params
1789
1790    def plot_episodes(
1791        self,
1792        episode_names: List,
1793        metrics: List,
1794        modes: List = None,
1795        title: str = None,
1796        episode_labels: List = None,
1797        save_path: str = None,
1798        add_hlines: List = None,
1799        epoch_limits: List = None,
1800        colors: List = None,
1801        add_highpoint_hlines: bool = False,
1802    ) -> None:
1803        """
1804        Plot episode training curves
1805
1806        Parameters
1807        ----------
1808        episode_names : list
1809            a list of episode names to plot; to plot to episodes in one line combine them in a list
1810            (e.g. ['episode1', ['episode2', 'episode3']] to plot episode2 and episode3 as one experiment)
1811        metrics : list
1812            a list of metric to plot
1813        modes : list, optional
1814            a list of modes to plot ('train' and/or 'val'; `['val']` by default)
1815        title : str, optional
1816            title for the plot
1817        episode_labels : list, optional
1818            a list of strings used to label the curves (has to be the same length as episode_names)
1819        save_path : str, optional
1820            the path to save the resulting plot
1821        add_hlines : list, optional
1822            a list of float values (or (value, label) tuples) to mark with horizontal lines
1823        colors: list, optional
1824            a list of matplotlib colors
1825        add_highpoint_hlines : bool, default False
1826            if `True`, horizontal lines will be added at the highest value of each episode
1827        """
1828
1829        if modes is None:
1830            modes = ["val"]
1831        if add_hlines is None:
1832            add_hlines = []
1833        logs = []
1834        epochs = []
1835        labels = []
1836        if episode_labels is not None:
1837            assert len(episode_labels) == len(episode_names)
1838        for name_i, name in enumerate(episode_names):
1839            log_params = product(metrics, modes)
1840            for metric, mode in log_params:
1841                if episode_labels is not None:
1842                    label = episode_labels[name_i]
1843                else:
1844                    label = deepcopy(name)
1845                if len(modes) != 1:
1846                    label += f"_{mode}"
1847                if len(metrics) != 1:
1848                    label += f"_{metric}"
1849                labels.append(label)
1850                if isinstance(name, Iterable) and not isinstance(name, str):
1851                    epoch_list = defaultdict(lambda: [])
1852                    multi_logs = defaultdict(lambda: [])
1853                    for i, n in enumerate(name):
1854                        runs = self._episodes().get_runs(n)
1855                        if len(runs) > 1:
1856                            for run in runs:
1857                                index = run.split("::")[-1]
1858                                if multi_logs[index] == []:
1859                                    if multi_logs["null"] is None:
1860                                        raise RuntimeError(
1861                                            "The run indices are not consistent across episodes!"
1862                                        )
1863                                    else:
1864                                        multi_logs[index] += multi_logs["null"]
1865                                multi_logs[index] += list(
1866                                    self._episode(run).get_metric_log(mode, metric)
1867                                )
1868                                start = (
1869                                    0
1870                                    if len(epoch_list[index]) == 0
1871                                    else epoch_list[index][-1]
1872                                )
1873                                epoch_list[index] += [
1874                                    x + start
1875                                    for x in self._episode(run).get_epoch_list(mode)
1876                                ]
1877                            multi_logs["null"] = None
1878                        else:
1879                            if len(multi_logs.keys()) > 1:
1880                                raise RuntimeError(
1881                                    "Cannot plot a single-run episode after a multi-run episode!"
1882                                )
1883                            multi_logs["null"] += list(
1884                                self._episode(n).get_metric_log(mode, metric)
1885                            )
1886                            start = (
1887                                0
1888                                if len(epoch_list["null"]) == 0
1889                                else epoch_list["null"][-1]
1890                            )
1891                            epoch_list["null"] += [
1892                                x + start for x in self._episode(n).get_epoch_list(mode)
1893                            ]
1894                    if len(multi_logs.keys()) == 1:
1895                        log = multi_logs["null"]
1896                        epochs.append(epoch_list["null"])
1897                    else:
1898                        log = tuple([v for k, v in multi_logs.items() if k != "null"])
1899                        epochs.append(
1900                            tuple([v for k, v in epoch_list.items() if k != "null"])
1901                        )
1902                else:
1903                    runs = self._episodes().get_runs(name)
1904                    if len(runs) > 1:
1905                        log = []
1906                        for run in runs:
1907                            tracked_metrics = self._episode(run).get_metrics()
1908                            if metric in tracked_metrics:
1909                                log.append(
1910                                    list(
1911                                        self._episode(run).get_metric_log(mode, metric)
1912                                    )
1913                                )
1914                            else:
1915                                relevant = []
1916                                for m in tracked_metrics:
1917                                    m_split = m.split("_")
1918                                    if (
1919                                        "_".join(m_split[:-1]) == metric
1920                                        and m_split[-1].isnumeric()
1921                                    ):
1922                                        relevant.append(m)
1923                                if len(relevant) == 0:
1924                                    raise ValueError(
1925                                        f"The {metric} metric was not tracked at {run}"
1926                                    )
1927                                arr = 0
1928                                for m in relevant:
1929                                    arr += self._episode(run).get_metric_log(mode, m)
1930                                arr /= len(relevant)
1931                                log.append(list(arr))
1932                        log = tuple(log)
1933                        epochs.append(
1934                            tuple(
1935                                [
1936                                    self._episode(run).get_epoch_list(mode)
1937                                    for run in runs
1938                                ]
1939                            )
1940                        )
1941                    else:
1942                        tracked_metrics = self._episode(name).get_metrics()
1943                        if metric in tracked_metrics:
1944                            log = list(self._episode(name).get_metric_log(mode, metric))
1945                        else:
1946                            relevant = []
1947                            for m in tracked_metrics:
1948                                m_split = m.split("_")
1949                                if (
1950                                    "_".join(m_split[:-1]) == metric
1951                                    and m_split[-1].isnumeric()
1952                                ):
1953                                    relevant.append(m)
1954                            if len(relevant) == 0:
1955                                raise ValueError(
1956                                    f"The {metric} metric was not tracked at {name}"
1957                                )
1958                            arr = 0
1959                            for m in relevant:
1960                                arr += self._episode(name).get_metric_log(mode, m)
1961                            arr /= len(relevant)
1962                            log = list(arr)
1963                        epochs.append(self._episode(name).get_epoch_list(mode))
1964                logs.append(log)
1965        # if episode_labels is not None:
1966        #     print(f'{len(episode_labels)=}, {len(logs)=}')
1967        #     if len(episode_labels) != len(logs):
1968
1969        #         raise ValueError(
1970        #             f"The length of episode_labels ({len(episode_labels)}) has to be equal to the length of "
1971        #             f"curves ({len(logs)})!"
1972        #         )
1973        #     else:
1974        #         labels = episode_labels
1975        if colors is None:
1976            colors = cm.rainbow(np.linspace(0, 1, len(logs)))
1977        if len(colors) != len(logs):
1978            raise ValueError(
1979                "The length of colors has to be equal to the length of curves (metrics * modes * episode_names)!"
1980            )
1981        plt.figure()
1982        length = 0
1983        for log, label, color, epoch_list in zip(logs, labels, colors, epochs):
1984            if type(log) is list:
1985                if len(log) > length:
1986                    length = len(log)
1987                plt.plot(
1988                    epoch_list,
1989                    log,
1990                    label=label,
1991                    color=color,
1992                )
1993                if add_highpoint_hlines:
1994                    plt.axhline(np.max(log), linestyle="dashed", color=color)
1995            else:
1996                for l, xx in zip(log, epoch_list):
1997                    if len(l) > length:
1998                        length = len(l)
1999                    plt.plot(
2000                        xx,
2001                        l,
2002                        color=color,
2003                        alpha=0.2,
2004                    )
2005                if not all([len(x) == len(log[0]) for x in log]):
2006                    warnings.warn(
2007                        f"Got logs with unequal lengths in parallel runs for {label}"
2008                    )
2009                    log = list(log)
2010                    epoch_list = list(epoch_list)
2011                    for i, x in enumerate(epoch_list):
2012                        to_remove = []
2013                        for j, y in enumerate(x[1:]):
2014                            if y <= x[j - 1]:
2015                                y_ind = x.index(y)
2016                                to_remove += list(range(y_ind, j))
2017                        epoch_list[i] = [
2018                            y for j, y in enumerate(x) if j not in to_remove
2019                        ]
2020                        log[i] = [y for j, y in enumerate(log[i]) if j not in to_remove]
2021                    length = min([len(x) for x in log])
2022                    for i in range(len(log)):
2023                        log[i] = log[i][:length]
2024                        epoch_list[i] = epoch_list[i][:length]
2025                    if not all([x == epoch_list[0] for x in epoch_list]):
2026                        raise RuntimeError(
2027                            f"Got different epoch indices in parallel runs for {label}"
2028                        )
2029                mean = np.array(log).mean(0)
2030                plt.plot(
2031                    epoch_list[0],
2032                    mean,
2033                    label=label,
2034                    color=color,
2035                )
2036                if add_highpoint_hlines:
2037                    plt.axhline(np.max(mean), linestyle="dashed", color=color)
2038        for x in add_hlines:
2039            label = None
2040            if isinstance(x, Iterable):
2041                x, label = x
2042            plt.axhline(x, label=label)
2043            plt.xlim((0, length))
2044
2045        plt.legend()
2046        plt.xlabel("epochs")
2047        if len(metrics) == 1:
2048            plt.ylabel(metrics[0])
2049        else:
2050            plt.ylabel("value")
2051        if title is None:
2052            if len(episode_names) == 1:
2053                title = episode_names[0]
2054            elif len(metrics) == 1:
2055                title = metrics[0]
2056        if epoch_limits is not None:
2057            plt.xlim(epoch_limits)
2058        if title is not None:
2059            plt.title(title)
2060        plt.show()
2061        if save_path is not None:
2062            plt.savefig(save_path)
2063
2064    def update_parameters(
2065        self,
2066        parameters_update: Dict = None,
2067        load_search: str = None,
2068        load_parameters: List = None,
2069        round_to_binary: List = None,
2070    ) -> None:
2071        """
2072        Update the parameters in the project config files
2073
2074        Parameters
2075        ----------
2076        parameters_update : dict, optional
2077            a dictionary of parameter updates
2078        load_search : str, optional
2079            the name of hyperparameter search results to load to config
2080        load_parameters : list, optional
2081            a list of lists of string names of the parameters to load from the searches
2082        round_to_binary : list, optional
2083            a list of string names of the loaded parameters that should be rounded to the nearest power of two
2084        """
2085
2086        keys = [
2087            "general",
2088            "losses",
2089            "metrics",
2090            "ssl",
2091            "training",
2092            "data",
2093        ]
2094        parameters = self._read_parameters(catch_blanks=False)
2095        if parameters_update is not None:
2096            if "model" in parameters_update:
2097                model_params = parameters_update.pop("model")
2098            else:
2099                model_params = None
2100            if "features" in parameters_update:
2101                feat_params = parameters_update.pop("features")
2102            else:
2103                feat_params = None
2104            if "augmentations" in parameters_update:
2105                aug_params = parameters_update.pop("augmentations")
2106            else:
2107                aug_params = None
2108            parameters = self._update(parameters, parameters_update)
2109            model_name = parameters["general"]["model_name"]
2110            parameters["model"] = self._open_yaml(
2111                os.path.join(self.project_path, "config", "model", f"{model_name}.yaml")
2112            )
2113            if model_params is not None:
2114                parameters["model"] = self._update(parameters["model"], model_params)
2115            feat_name = parameters["general"]["feature_extraction"]
2116            parameters["features"] = self._open_yaml(
2117                os.path.join(
2118                    self.project_path, "config", "features", f"{feat_name}.yaml"
2119                )
2120            )
2121            if feat_params is not None:
2122                parameters["features"] = self._update(
2123                    parameters["features"], feat_params
2124                )
2125            aug_name = options.extractor_to_transformer[
2126                parameters["general"]["feature_extraction"]
2127            ]
2128            parameters["augmentations"] = self._open_yaml(
2129                os.path.join(
2130                    self.project_path, "config", "augmentations", f"{aug_name}.yaml"
2131                )
2132            )
2133            if aug_params is not None:
2134                parameters["augmentations"] = self._update(
2135                    parameters["augmentations"], aug_params
2136                )
2137        if load_search is not None:
2138            parameters_update, model_name = self._searches().get_best_params(
2139                load_search, load_parameters, round_to_binary
2140            )
2141            parameters["general"]["model_name"] = model_name
2142            parameters["model"] = self._open_yaml(
2143                os.path.join(self.project_path, "config", "model", f"{model_name}.yaml")
2144            )
2145            parameters = self._update(parameters, parameters_update)
2146        for key in keys:
2147            with open(
2148                os.path.join(self.project_path, "config", f"{key}.yaml"), "w", encoding="utf-8"
2149            ) as f:
2150                YAML().dump(parameters[key], f)
2151        model_name = parameters["general"]["model_name"]
2152        model_path = os.path.join(
2153            self.project_path, "config", "model", f"{model_name}.yaml"
2154        )
2155        with open(model_path, "w", encoding="utf-8") as f:
2156            YAML().dump(parameters["model"], f)
2157        features_name = parameters["general"]["feature_extraction"]
2158        features_path = os.path.join(
2159            self.project_path, "config", "features", f"{features_name}.yaml"
2160        )
2161        with open(features_path, "w", encoding="utf-8") as f:
2162            YAML().dump(parameters["features"], f)
2163        aug_name = options.extractor_to_transformer[features_name]
2164        aug_path = os.path.join(
2165            self.project_path, "config", "augmentations", f"{aug_name}.yaml"
2166        )
2167        with open(aug_path, "w", encoding="utf-8") as f:
2168            YAML().dump(parameters["augmentations"], f)
2169
2170    def get_summary(
2171        self,
2172        episode_names: list,
2173        method: str = "last",
2174        average: int = 1,
2175        metrics: List = None,
2176    ) -> Dict:
2177        """
2178        Get a summary of episode statistics
2179
2180        If the episode has multiple runs, the statistics will be aggregated over all of them.
2181
2182        Parameters
2183        ----------
2184        episode_name : str
2185            the name of the episode
2186        method : ["best", "last"]
2187            the method for choosing the epochs
2188        average : int, default 1
2189            the number of epochs to average over (for each run)
2190        metrics : list, optional
2191            a list of metrics
2192
2193        Returns
2194        -------
2195        statistics : dict
2196            a nested dictionary where first-level keys are metric names and second-level keys are 'mean' for the mean
2197            and 'std' for the standard deviation
2198        """
2199
2200        runs = []
2201        for episode_name in episode_names:
2202            runs_ep = self._episodes().get_runs(episode_name)
2203            if len(runs_ep) == 0:
2204                raise RuntimeError(
2205                    f"There is no {episode_name} episode in the project memory"
2206                )
2207            runs += runs_ep
2208        if metrics is None:
2209            metrics = self._episode(runs[0]).get_metrics()
2210
2211        values = {m: [] for m in metrics}
2212        for run in runs:
2213            for m in metrics:
2214                log = self._episode(run).get_metric_log(mode="val", metric_name=m)
2215                if method == "best":
2216                    log = sorted(log)
2217                    values[m] += list(log[-average:])
2218                elif method == "last":
2219                    if len(log) == 0:
2220                        episodes = self._episodes().data
2221                        if average == 1 and ("results", m) in episodes.columns:
2222                            values[m] += [episodes.loc[run, ("results", m)]]
2223                        else:
2224                            raise RuntimeError(f"Did not find {m} metric for {run} run")
2225                    values[m] += list(log[-average:])
2226                elif method.startswith("epoch"):
2227                    epoch = int(method[5:]) - 1
2228                    pars = self._episodes().load_parameters(run)
2229                    step = int(pars["training"]["validation_interval"])
2230                    values[m] += [log[epoch // step]]
2231                else:
2232                    raise ValueError(
2233                        f"The {method} method is not recognized! Please choose from ['last', 'best', 'epoch...']"
2234                    )
2235        statistics = defaultdict(lambda: {})
2236        for m, v in values.items():
2237            statistics[m]["mean"] = np.mean(v)
2238            statistics[m]["std"] = np.std(v)
2239        print(f"SUMMARY {episode_names}")
2240        for m, v in statistics.items():
2241            print(f'{m}: mean {v["mean"]:.3f}, std {v["std"]:.3f}')
2242        print("\n")
2243        return dict(statistics)
2244
2245    @staticmethod
2246    def remove_project(name: str, projects_path: str = None) -> None:
2247        """
2248        Remove all project files and experiment records and results
2249        """
2250
2251        if projects_path is None:
2252            projects_path = os.path.join(str(Path.home()), "DLC2Action")
2253        project_path = os.path.join(projects_path, name)
2254        if os.path.exists(project_path):
2255            shutil.rmtree(project_path)
2256
2257    def remove_saved_features(
2258        self,
2259        dataset_names: List = None,
2260        exceptions: List = None,
2261        remove_active: bool = False,
2262    ) -> None:
2263        """
2264        Remove saved pre-computed dataset files
2265
2266        By default, all pre-computed features will be deleted.
2267        No essential information can get lost, storing them only saves time. Be careful with deleting datasets
2268        while training or inference is happening though.
2269
2270        Parameters
2271        ----------
2272        dataset_names : list, optional
2273            a list of dataset names to delete (by default all names are added)
2274        exceptions : list, optional
2275            a list of dataset names to not be deleted
2276        remove_active : bool, default False
2277            if `False`, datasets used by unfinished episodes will not be deleted
2278        """
2279
2280        print("Removing datasets...")
2281        if dataset_names is None:
2282            dataset_names = []
2283        if exceptions is None:
2284            exceptions = []
2285        if not remove_active:
2286            exceptions += self._episodes().get_active_datasets()
2287        dataset_path = os.path.join(self.project_path, "saved_datasets")
2288        if os.path.exists(dataset_path):
2289            if dataset_names == []:
2290                dataset_names = set([f.split(".")[0] for f in os.listdir(dataset_path)])
2291
2292            to_remove = [
2293                x
2294                for x in dataset_names
2295                if os.path.exists(os.path.join(dataset_path, x)) and x not in exceptions
2296            ]
2297            if len(to_remove) > 2:
2298                to_remove = tqdm(to_remove)
2299            for dataset in to_remove:
2300                shutil.rmtree(os.path.join(dataset_path, dataset))
2301            to_remove = [
2302                f"{x}.pickle"
2303                for x in dataset_names
2304                if os.path.exists(os.path.join(dataset_path, f"{x}.pickle"))
2305                and x not in exceptions
2306            ]
2307            for dataset in to_remove:
2308                os.remove(os.path.join(dataset_path, dataset))
2309            names = self._saved_datasets().dataset_names()
2310            self._saved_datasets().remove(names)
2311        print("\n")
2312
2313    def remove_extra_checkpoints(
2314        self, episode_names: List = None, exceptions: List = None
2315    ) -> None:
2316        """
2317        Remove intermediate model checkpoint files (only leave the results of the last epoch)
2318
2319        By default, all intermediate checkpoints will be deleted.
2320        Files in the model folder that are not associated with any record in the meta files are also deleted.
2321
2322        Parameters
2323        ----------
2324        episode_names : list, optional
2325            a list of episode names to clean (by default all names are added)
2326        exceptions : list, optional
2327            a list of episode names to not clean
2328        """
2329
2330        model_path = os.path.join(self.project_path, "results", "model")
2331        try:
2332            all_names = self._episodes().data.index
2333        except:
2334            all_names = os.listdir(model_path)
2335        if episode_names is None:
2336            episode_names = all_names
2337        if exceptions is None:
2338            exceptions = []
2339        to_remove = [x for x in episode_names if x not in exceptions]
2340        folders = os.listdir(model_path)
2341        for folder in folders:
2342            if folder not in all_names:
2343                shutil.rmtree(os.path.join(model_path, folder))
2344            elif folder in to_remove:
2345                files = os.listdir(os.path.join(model_path, folder))
2346                for file in sorted(files)[:-1]:
2347                    os.remove(os.path.join(model_path, folder, file))
2348
2349    def remove_search(self, search_name: str) -> None:
2350        """
2351        Remove a hyperparameter search record
2352
2353        Parameters
2354        ----------
2355        search_name : str
2356            the name of the search to remove
2357        """
2358
2359        self._searches().remove_episode(search_name)
2360        graph_path = os.path.join(self.project_path, "results", "searches", search_name)
2361        if os.path.exists(graph_path):
2362            shutil.rmtree(graph_path)
2363
2364    def remove_prediction(self, prediction_name: str) -> None:
2365        """
2366        Remove a prediction record
2367
2368        Parameters
2369        ----------
2370        prediction_name : str
2371            the name of the prediction to remove
2372        """
2373
2374        self._predictions().remove_episode(prediction_name)
2375        prediction_path = os.path.join(
2376            self.project_path, "results", "predictions", prediction_name
2377        )
2378        if os.path.exists(prediction_path):
2379            shutil.rmtree(prediction_path)
2380
2381    def remove_episode(self, episode_name: str) -> None:
2382        """
2383        Remove all model, logs and metafile records related to an episode
2384
2385        Parameters
2386        ----------
2387        episode_name : str
2388            the name of the episode to remove
2389        """
2390
2391        runs = self._episodes().get_runs(episode_name)
2392        runs.append(episode_name)
2393        for run in runs:
2394            self._episodes().remove_episode(run)
2395            model_path = os.path.join(self.project_path, "results", "model", run)
2396            if os.path.exists(model_path):
2397                shutil.rmtree(model_path)
2398            log_path = os.path.join(self.project_path, "results", "logs", f"{run}.txt")
2399            if os.path.exists(log_path):
2400                os.remove(log_path)
2401
2402    def prune_unfinished(self, exceptions: List = None) -> None:
2403        """
2404        Remove all interrupted episodes
2405
2406        Remove all episodes that either don't have a log file or have less epochs in the log file than in
2407        the training parameters or have a model folder but not a record. Note that it can remove episodes that are
2408        currently running!
2409
2410        Parameters
2411        ----------
2412        exceptions : list
2413            the episodes to keep even if they are interrupted
2414
2415        Returns
2416        -------
2417        pruned : list
2418            a list of the episode names that were pruned
2419        """
2420
2421        if exceptions is None:
2422            exceptions = []
2423        unfinished = self._episodes().unfinished_episodes()
2424        unfinished = [x for x in unfinished if x not in exceptions]
2425        model_folders = os.listdir(os.path.join(self.project_path, "results", "model"))
2426        unfinished += [
2427            x for x in model_folders if x not in self._episodes().list_episodes().index
2428        ]
2429        print(f"PRUNING {unfinished}")
2430        for episode_name in unfinished:
2431            self.remove_episode(episode_name)
2432        print(f"\n")
2433        return unfinished
2434
2435    def prediction_path(self, prediction_name: str) -> str:
2436        """
2437        Get the path where prediction files are saved
2438
2439        Parameters
2440        ----------
2441        prediction_name : str
2442            name of the prediction
2443
2444        Returns
2445        -------
2446        prediction_path : str
2447            the file path
2448        """
2449
2450        return os.path.join(
2451            self.project_path, "results", "predictions", f"{prediction_name}"
2452        )
2453
2454    @classmethod
2455    def print_data_types(cls):
2456        print("DATA TYPES:")
2457        for key, value in cls.data_types().items():
2458            print(f"{key}:")
2459            print(value.__doc__)
2460
2461    @classmethod
2462    def print_annotation_types(cls):
2463        print("ANNOTATION TYPES:")
2464        for key, value in cls.annotation_types().items():
2465            print(f"{key}:")
2466            print(value.__doc__)
2467
2468    @staticmethod
2469    def data_types() -> List:
2470        """
2471        Get available data types
2472
2473        Returns
2474        -------
2475        list
2476            available data types
2477        """
2478
2479        return options.input_stores
2480
2481    @staticmethod
2482    def annotation_types() -> List:
2483        """
2484        Get available annotation types
2485
2486        Returns
2487        -------
2488        list
2489            available annotation types
2490        """
2491
2492        return options.annotation_stores
2493
2494    def _save_mask(self, file: Dict, mask_name: str):
2495        """
2496        Save a mask file
2497        """
2498
2499        if not os.path.exists(self._mask_path()):
2500            os.mkdir(self._mask_path())
2501        with open(os.path.join(self._mask_path(), mask_name + ".pickle"), "wb") as f:
2502            pickle.dump(file, f)
2503
2504    def _load_mask(self, mask_name: str) -> Dict:
2505        """
2506        Load a mask file
2507        """
2508
2509        with open(os.path.join(self._mask_path(), mask_name + ".pickle"), "rb") as f:
2510            data = pickle.load(f)
2511        return data
2512
2513    def _thresholds(self) -> DecisionThresholds:
2514        """
2515        Get the decision thresholds meta object
2516        """
2517
2518        return DecisionThresholds(self._thresholds_path())
2519
2520    def _episodes(self) -> SavedRuns:
2521        """
2522        Get the episodes meta object
2523
2524        Returns
2525        -------
2526        episodes : SavedRuns
2527            the episodes meta object
2528        """
2529
2530        try:
2531            return SavedRuns(self._episodes_path(), self.project_path)
2532        except:
2533            self.load_metadata_backup()
2534            return SavedRuns(self._episodes_path(), self.project_path)
2535
2536    def _predictions(self) -> SavedRuns:
2537        """
2538        Get the predictions meta object
2539
2540        Returns
2541        -------
2542        predictions : SavedRuns
2543            the predictions meta object
2544        """
2545
2546        try:
2547            return SavedRuns(self._predictions_path(), self.project_path)
2548        except:
2549            self.load_metadata_backup()
2550            return SavedRuns(self._predictions_path(), self.project_path)
2551
2552    def _saved_datasets(self) -> SavedStores:
2553        """
2554        Get the datasets meta object
2555
2556        Returns
2557        -------
2558        datasets : SavedStores
2559            the datasets meta object
2560        """
2561
2562        try:
2563            return SavedStores(self._saved_datasets_path())
2564        except:
2565            self.load_metadata_backup()
2566            return SavedStores(self._saved_datasets_path())
2567
2568    def _prediction(self, name: str) -> Run:
2569        """
2570        Get a prediction meta object
2571
2572        Parameters
2573        ----------
2574        name : str
2575            episode name
2576
2577        Returns
2578        -------
2579        prediction : Run
2580            the prediction meta object
2581        """
2582
2583        try:
2584            return Run(name, self.project_path, meta_path=self._predictions_path())
2585        except:
2586            self.load_metadata_backup()
2587            return Run(name, self.project_path, meta_path=self._predictions_path())
2588
2589    def _episode(self, name: str) -> Run:
2590        """
2591        Get an episode meta object
2592
2593        Parameters
2594        ----------
2595        name : str
2596            episode name
2597
2598        Returns
2599        -------
2600        episode : Run
2601            the episode meta object
2602        """
2603
2604        try:
2605            return Run(name, self.project_path, meta_path=self._episodes_path())
2606        except:
2607            self.load_metadata_backup()
2608            return Run(name, self.project_path, meta_path=self._episodes_path())
2609
2610    def _searches(self) -> Searches:
2611        """
2612        Get the hyperparameter search meta object
2613
2614        Returns
2615        -------
2616        searches : Searches
2617            the searches meta object
2618        """
2619
2620        try:
2621            return Searches(self._searches_path(), self.project_path)
2622        except:
2623            self.load_metadata_backup()
2624            return Searches(self._searches_path(), self.project_path)
2625
2626    def _update_configs(self) -> None:
2627        """
2628        Update the project config files with newly added files and parameters
2629        """
2630
2631        self.update_parameters({"data": {"data_path": self.data_path}})
2632        folders = ["augmentations", "features", "model"]
2633        original_path = os.path.join(
2634            os.path.dirname(os.path.dirname(__file__)), "config"
2635        )
2636        project_path = os.path.join(self.project_path, "config")
2637        filenames = [x for x in os.listdir(original_path) if x.endswith("yaml")]
2638        for folder in folders:
2639            filenames += [
2640                os.path.join(folder, x)
2641                for x in os.listdir(os.path.join(original_path, folder))
2642            ]
2643        filenames.append(os.path.join("data", f"{self.data_type}.yaml"))
2644        if self.annotation_type != "none":
2645            filenames.append(os.path.join("annotation", f"{self.annotation_type}.yaml"))
2646        for file in filenames:
2647            filepath_original = os.path.join(original_path, file)
2648            if file.startswith("data") or file.startswith("annotation"):
2649                file = os.path.basename(file)
2650            filepath_project = os.path.join(project_path, file)
2651            if not os.path.exists(filepath_project):
2652                shutil.copy(filepath_original, filepath_project)
2653            else:
2654                original_pars = self._open_yaml(filepath_original)
2655                project_pars = self._open_yaml(filepath_project)
2656                to_remove = []
2657                for key, value in project_pars.items():
2658                    if key not in original_pars:
2659                        if key not in ["data_type", "annotation_type"]:
2660                            to_remove.append(key)
2661                for key in to_remove:
2662                    project_pars.pop(key)
2663                to_remove = []
2664                for key, value in original_pars.items():
2665                    if key in project_pars:
2666                        to_remove.append(key)
2667                for key in to_remove:
2668                    original_pars.pop(key)
2669                project_pars = self._update(project_pars, original_pars)
2670                with open(filepath_project, "w", encoding="utf-8") as f:
2671                    YAML().dump(project_pars, f)
2672
2673    def _update_project(self) -> None:
2674        """
2675        Update project files with the current version
2676        """
2677
2678        version_file = self._version_path()
2679        ok = True
2680        if not os.path.exists(version_file):
2681            ok = False
2682        else:
2683            with open(version_file) as f:
2684                project_version = f.read()
2685            if project_version < __version__:
2686                ok = False
2687            elif project_version > __version__:
2688                warnings.warn(
2689                    f"The project expects a higher dlc2action version ({project_version}), please update!"
2690                )
2691        if not ok:
2692            project_config_path = os.path.join(self.project_path, "config")
2693            config_path = os.path.join(
2694                os.path.dirname(os.path.dirname(__path__)), "config"
2695            )
2696            episodes = self._episodes()
2697            folders = ["annotation", "augmentations", "data", "features", "model"]
2698
2699            project_annotation_configs = os.listdir(
2700                os.path.join(project_config_path, "annotation")
2701            )
2702            annotation_configs = os.listdir(os.path.join(config_path, "annotation"))
2703            for ann_config in annotation_configs:
2704                if ann_config not in project_annotation_configs:
2705                    shutil.copytree(
2706                        os.path.join(config_path, "annotation", ann_config),
2707                        os.path.join(project_config_path, "annotation", ann_config),
2708                        dirs_exist_ok=True,
2709                    )
2710                else:
2711                    project_pars = self._open_yaml(
2712                        os.path.join(project_config_path, "annotation", ann_config)
2713                    )
2714                    pars = self._open_yaml(
2715                        os.path.join(config_path, "annotation", ann_config)
2716                    )
2717                    new_keys = set(pars.keys()) - set(project_pars.keys())
2718                    for key in new_keys:
2719                        project_pars[key] = pars[key]
2720                        c = self._get_comment(pars.ca.items.get(key))
2721                        project_pars.yaml_add_eol_comment(c, key=key)
2722                        episodes.update(
2723                            condition=f"general/annotation_type::={ann_config}",
2724                            update={f"data/{key}": pars[key]},
2725                        )
2726
2727    def _initialize_project(
2728        self,
2729        data_type: str,
2730        annotation_type: str = None,
2731        data_path: str = None,
2732        annotation_path: str = None,
2733        copy: bool = True,
2734    ) -> None:
2735        """
2736        Initialize a new project
2737        """
2738
2739        if data_type not in self.data_types():
2740            raise ValueError(
2741                f"The {data_type} data type is not available. "
2742                f"Please choose from {self.data_types()}"
2743            )
2744        if annotation_type not in self.annotation_types():
2745            raise ValueError(
2746                f"The {annotation_type} annotation type is not available. "
2747                f"Please choose from {self.annotation_types()}"
2748            )
2749        os.mkdir(self.project_path)
2750        folders = ["results", "saved_datasets", "meta", "config"]
2751        for f in folders:
2752            os.mkdir(os.path.join(self.project_path, f))
2753        results_subfolders = [
2754            "model",
2755            "logs",
2756            "predictions",
2757            "splits",
2758            "searches",
2759        ]
2760        for sf in results_subfolders:
2761            os.mkdir(os.path.join(self.project_path, "results", sf))
2762        if data_path is not None:
2763            if copy:
2764                os.mkdir(os.path.join(self.project_path, "data"))
2765                shutil.copytree(
2766                    data_path,
2767                    os.path.join(self.project_path, "data"),
2768                    dirs_exist_ok=True,
2769                )
2770                data_path = os.path.join(self.project_path, "data")
2771        if annotation_path is not None:
2772            if copy:
2773                os.mkdir(os.path.join(self.project_path, "annotation"))
2774                shutil.copytree(
2775                    annotation_path,
2776                    os.path.join(self.project_path, "annotation"),
2777                    dirs_exist_ok=True,
2778                )
2779                annotation_path = os.path.join(self.project_path, "annotation")
2780        self._generate_config(
2781            data_type,
2782            annotation_type,
2783            data_path=data_path,
2784            annotation_path=annotation_path,
2785        )
2786        self._generate_meta()
2787
2788    def _read_types(self) -> Tuple[str, str]:
2789        """
2790        Get data type and annotation type from existing project files
2791        """
2792
2793        config_path = os.path.join(self.project_path, "config", "general.yaml")
2794        with open(config_path) as f:
2795            pars = YAML().load(f)
2796        data_type = pars["data_type"]
2797        annotation_type = pars["annotation_type"]
2798        return annotation_type, data_type
2799
2800    def _read_paths(self) -> Tuple[str, str]:
2801        """
2802        Get data type and annotation type from existing project files
2803        """
2804
2805        config_path = os.path.join(self.project_path, "config", "data.yaml")
2806        with open(config_path) as f:
2807            pars = YAML().load(f)
2808        data_path = pars["data_path"]
2809        annotation_path = pars["annotation_path"]
2810        return annotation_path, data_path
2811
2812    def _generate_config(
2813        self, data_type: str, annotation_type: str, data_path: str, annotation_path: str
2814    ) -> None:
2815        """
2816        Initialize the config files
2817        """
2818
2819        default_path = os.path.join(
2820            os.path.dirname(os.path.dirname(__file__)), "config"
2821        )
2822        config_path = os.path.join(self.project_path, "config")
2823        files = ["losses", "metrics", "ssl", "training"]
2824        for f in files:
2825            shutil.copy(os.path.join(default_path, f"{f}.yaml"), config_path)
2826        shutil.copytree(
2827            os.path.join(default_path, "model"), os.path.join(config_path, "model")
2828        )
2829        shutil.copytree(
2830            os.path.join(default_path, "features"),
2831            os.path.join(config_path, "features"),
2832        )
2833        shutil.copytree(
2834            os.path.join(default_path, "augmentations"),
2835            os.path.join(config_path, "augmentations"),
2836        )
2837        yaml = YAML()
2838        data_param_path = os.path.join(default_path, "data", f"{data_type}.yaml")
2839        if os.path.exists(data_param_path):
2840            with open(data_param_path, encoding="utf-8") as f:
2841                data_params = yaml.load(f)
2842        if data_params is None:
2843            data_params = {}
2844        if annotation_type is None:
2845            ann_params = {}
2846        else:
2847            ann_param_path = os.path.join(
2848                default_path, "annotation", f"{annotation_type}.yaml"
2849            )
2850            if os.path.exists(ann_param_path):
2851                ann_params = self._open_yaml(ann_param_path)
2852            elif annotation_type == "none":
2853                ann_params = {}
2854            else:
2855                raise ValueError(
2856                    f"The {annotation_type} data type is not available. "
2857                    f"Please choose from {BehaviorDataset.annotation_types()}"
2858                )
2859        if ann_params is None:
2860            ann_params = {}
2861        data_params = self._update(data_params, ann_params)
2862        data_params["data_path"] = data_path
2863        data_params["annotation_path"] = annotation_path
2864        with open(os.path.join(config_path, "data.yaml"), "w", encoding="utf-8") as f:
2865            yaml.dump(data_params, f)
2866        with open(os.path.join(default_path, "general.yaml"), encoding="utf-8") as f:
2867            general_params = yaml.load(f)
2868        general_params["data_type"] = data_type
2869        general_params["annotation_type"] = annotation_type
2870        with open(os.path.join(config_path, "general.yaml"), "w", encoding="utf-8") as f:
2871            yaml.dump(general_params, f)
2872
2873    def _generate_meta(self) -> None:
2874        """
2875        Initialize the meta files
2876        """
2877
2878        config_file = os.path.join(self.project_path, "config")
2879        meta_fields = ["time"]
2880        columns = [("meta", field) for field in meta_fields]
2881        episodes = pd.DataFrame(columns=pd.MultiIndex.from_tuples(columns))
2882        episodes.to_pickle(self._episodes_path())
2883        meta_fields = ["time", "objective"]
2884        result_fields = ["best_params", "best_value"]
2885        columns = [("meta", field) for field in meta_fields] + [
2886            ("results", field) for field in result_fields
2887        ]
2888        searches = pd.DataFrame(columns=pd.MultiIndex.from_tuples(columns))
2889        searches.to_pickle(self._searches_path())
2890        meta_fields = ["time"]
2891        columns = [("meta", field) for field in meta_fields]
2892        predictions = pd.DataFrame(columns=pd.MultiIndex.from_tuples(columns))
2893        predictions.to_pickle(self._predictions_path())
2894        with open(os.path.join(config_file, "data.yaml")) as f:
2895            data_keys = list(YAML().load(f).keys())
2896        saved_data = pd.DataFrame(columns=data_keys)
2897        saved_data.to_pickle(self._saved_datasets_path())
2898        pd.DataFrame().to_pickle(self._thresholds_path())
2899        # with open(self._version_path()) as f:
2900        #     f.write(__version__)
2901
2902    def _open_yaml(self, path: str) -> CommentedMap:
2903        """
2904        Load a parameter dictionary from a .yaml file
2905        """
2906
2907        with open(path, encoding="utf-8") as f:
2908            data = YAML().load(f)
2909        if data is None:
2910            data = {}
2911        return data
2912
2913    def _compare(self, d: Dict, u: Dict, allow_diff: float = 1e-7):
2914        """
2915        Compare nested dictionaries with 'almost equal' condition
2916        """
2917
2918        ok = True
2919        if u.keys() != d.keys():
2920            ok = False
2921        else:
2922            for k, v in u.items():
2923                if isinstance(v, Mapping):
2924                    ok = self._compare(d[k], v, allow_diff=allow_diff)
2925                else:
2926                    if isinstance(v, float) or isinstance(d[k], float):
2927                        if not isinstance(d[k], float) and not isinstance(d[k], int):
2928                            ok = False
2929                        elif not isinstance(v, float) and not isinstance(v, int):
2930                            ok = False
2931                        elif np.abs(v - d[k]) > allow_diff:
2932                            ok = False
2933                    elif v != d[k]:
2934                        ok = False
2935        return ok
2936
2937    def _check_comment(self, comment_sequence: List) -> bool:
2938        """
2939        Check if a comment already exists in a ruamel.yaml comment sequence
2940        """
2941
2942        if comment_sequence is None:
2943            return False
2944        c = self._get_comment(comment_sequence)
2945        if c != "":
2946            return True
2947        else:
2948            return False
2949
2950    def _get_comment(self, comment_sequence: List, strip=True) -> str:
2951        """
2952        Get the comment string from a ruamel.yaml comment sequence
2953        """
2954
2955        if comment_sequence is None:
2956            return ""
2957        c = ""
2958        for cm in comment_sequence:
2959            if cm is not None:
2960                if isinstance(cm, Iterable):
2961                    for c in cm:
2962                        if c is not None:
2963                            c = c.value
2964                            break
2965                    break
2966                else:
2967                    c = cm.value
2968                    break
2969        if strip:
2970            c = c.strip()
2971        return c
2972
2973    def _update(self, d: Union[CommentedMap, Dict], u: Union[CommentedMap, Dict]):
2974        """
2975        Update a nested dictionary
2976        """
2977
2978        if "general" in u and "model_name" in u["general"] and "model" in d:
2979            model_name = u["general"]["model_name"]
2980            if d["general"]["model_name"] != model_name:
2981                d["model"] = self._open_yaml(
2982                    os.path.join(
2983                        self.project_path, "config", "model", f"{model_name}.yaml"
2984                    )
2985                )
2986        d_copied = deepcopy(d)
2987        for k, v in u.items():
2988            if (
2989                k in d_copied
2990                and isinstance(d_copied[k], list)
2991                and isinstance(v, Mapping)
2992                and all([isinstance(x, int) for x in v.keys()])
2993            ):
2994                for kk, vv in v.items():
2995                    d_copied[k][kk] = vv
2996            elif (
2997                isinstance(v, Mapping)
2998                and k in d_copied
2999                and isinstance(d_copied[k], Mapping)
3000            ):
3001                if d_copied[k] is None:
3002                    d_k = CommentedMap()
3003                else:
3004                    d_k = d_copied[k]
3005                d_copied[k] = self._update(d_k, v)
3006            else:
3007                d_copied[k] = v
3008                if isinstance(u, CommentedMap) and u.ca.items.get(k) is not None:
3009                    c = self._get_comment(u.ca.items.get(k), strip=False)
3010                    if isinstance(d_copied, CommentedMap) and not self._check_comment(
3011                        d_copied.ca.items.get(k)
3012                    ):
3013                        d_copied.yaml_add_eol_comment(c, key=k)
3014        return d_copied
3015
3016    def _update_with_search(
3017        self,
3018        d: Dict,
3019        search_name: str,
3020        load_parameters: list = None,
3021        round_to_binary: list = None,
3022    ):
3023        """
3024        Update a dictionary with best parameters from a hyperparameter search
3025        """
3026
3027        u, _ = self._searches().get_best_params(
3028            search_name, load_parameters, round_to_binary
3029        )
3030        return self._update(d, u)
3031
3032    def _read_parameters(self, catch_blanks=True) -> Dict:
3033        """
3034        Compose a parameter dictionary to create a task from the config files
3035        """
3036
3037        config_path = os.path.join(self.project_path, "config")
3038        keys = [
3039            "data",
3040            "general",
3041            "losses",
3042            "metrics",
3043            "ssl",
3044            "training",
3045        ]
3046        parameters = {}
3047        for key in keys:
3048            parameters[key] = self._open_yaml(os.path.join(config_path, f"{key}.yaml"))
3049        features = parameters["general"]["feature_extraction"]
3050        parameters["features"] = self._open_yaml(
3051            os.path.join(config_path, "features", f"{features}.yaml")
3052        )
3053        transformer = options.extractor_to_transformer[features]
3054        parameters["augmentations"] = self._open_yaml(
3055            os.path.join(config_path, "augmentations", f"{transformer}.yaml")
3056        )
3057        model = parameters["general"]["model_name"]
3058        parameters["model"] = self._open_yaml(
3059            os.path.join(config_path, "model", f"{model}.yaml")
3060        )
3061        # input = parameters["general"]["input"]
3062        # parameters["model"] = self._open_yaml(
3063        #     os.path.join(config_path, "model", f"{model}.yaml")
3064        # )
3065        if catch_blanks:
3066            blanks = self._get_blanks()
3067            if len(blanks) > 0:
3068                self.list_blanks()
3069                raise ValueError(
3070                    f"Please fill in all the blanks before running experiments"
3071                )
3072        return parameters
3073
3074    def set_main_parameters(self, model_name: str = None, metric_names: List = None):
3075        """
3076        Select the model and the metrics
3077
3078        Parameters
3079        ----------
3080        model_name : str, optional
3081            model name; run `project.help("model") to find out more
3082        metric_names : list, optional
3083            a list of metric function names; run `project.help("metrics") to find out more
3084        """
3085
3086        pars = {"general": {}}
3087        if model_name is not None:
3088            assert model_name in options.models
3089            pars["general"]["model_name"] = model_name
3090        if metric_names is not None:
3091            for metric in metric_names:
3092                assert metric in options.metrics
3093            pars["general"]["metric_functions"] = metric_names
3094        self.update_parameters(pars)
3095
3096    def help(self, keyword: str = None):
3097        """
3098        Get information on available options
3099
3100        Parameters
3101        ----------
3102        keyword : str, optional
3103            the keyword for options (run without arguments to see which keywords are available)
3104
3105        """
3106
3107        if keyword is None:
3108            print("AVAILABLE HELP FUNCTIONS:")
3109            print("- Try running `project.help(keyword)` with the following keywords:")
3110            print("    - model: to get more information on available models,")
3111            print(
3112                "    - features: to get more information on available feature extraction modes,"
3113            )
3114            print(
3115                "    - partition_method: to get more information on available train/test/val partitioning methods,"
3116            )
3117            print("    - metrics: to see a list of available metric functions.")
3118            print("    - data: to see help for expected data structure")
3119            print(
3120                "- To start working with this project, first run `project.list_blanks()` to check which parameters need to be filled in."
3121            )
3122            print(
3123                "- After a model and metrics are set, run `project.list_basic_parameters()` to see a list of the most important parameters that you might want to modify"
3124            )
3125            print(
3126                f"- If you want to dig deeper, get the full dictionary with project._read_parameters() (it is a `ruamel.yaml.comments.CommentedMap` instance)."
3127            )
3128        elif keyword == "model":
3129            print("MODELS:")
3130            for key, model in options.models.items():
3131                print(f"{key}:")
3132                print(model.__doc__)
3133        elif keyword == "features":
3134            print("FEATURE EXTRACTORS:")
3135            for key, extractor in options.feature_extractors.items():
3136                print(f"{key}:")
3137                print(extractor.__doc__)
3138        elif keyword == "partition_method":
3139            print("PARTITION METHODS:")
3140            print(
3141                BehaviorDataset.partition_train_test_val.__doc__.split(
3142                    "The partitioning method:"
3143                )[1].split("val_frac :")[0]
3144            )
3145        elif keyword == "metrics":
3146            print("METRICS:")
3147            for key, metric in options.metrics.items():
3148                print(f"{key}:")
3149                print(metric.__doc__)
3150        elif keyword == "data":
3151            print("DATA:")
3152            print(f"Video data: {self.data_type}")
3153            print(options.input_stores[self.data_type].__doc__)
3154            print(f"Annotation data: {self.annotation_type}")
3155            print(options.annotation_stores[self.annotation_type].__doc__)
3156            print(
3157                "Annotation path and data path don't have to be separate, you can keep everything in one folder."
3158            )
3159        else:
3160            raise ValueError(f"The {keyword} keyword is not recognized")
3161        print("\n")
3162
3163    def _process_value(self, value):
3164        if isinstance(value, str):
3165            value = f'"{value}"'
3166        elif isinstance(value, CommentedSet):
3167            value = {x for x in value}
3168        return value
3169
3170    def _get_blanks(self):
3171        caught = []
3172        parameters = self._read_parameters(catch_blanks=False)
3173        for big_key, big_value in parameters.items():
3174            for key, value in big_value.items():
3175                if value == "???":
3176                    caught.append(
3177                        (big_key, key, self._get_comment(big_value.ca.items.get(key)))
3178                    )
3179        return caught
3180
3181    def list_blanks(self, blanks=None):
3182        """
3183        List parameters that need to be filled in
3184
3185        Parameters
3186        ----------
3187        blanks : list, optional
3188            a list of the parameters to list, if already known
3189        """
3190
3191        if blanks is None:
3192            blanks = self._get_blanks()
3193        if len(blanks) > 0:
3194            to_update = defaultdict(lambda: [])
3195            for b, k, c in blanks:
3196                to_update[b].append((k, c))
3197            print("Before running experiments, please update all the blanks.")
3198            print("To do that, you can run this.")
3199            print("--------------------------------------------------------")
3200            print(f"project.update_parameters(")
3201            print(f"    {{")
3202            for big_key, keys in to_update.items():
3203                print(f'        "{big_key}": {{')
3204                for key, comment in keys:
3205                    print(f'            "{key}": ..., {comment}')
3206                print(f"        }},")
3207            print(f"    }}")
3208            print(")")
3209            print("--------------------------------------------------------")
3210            print("Replace ... with relevant values.")
3211        else:
3212            print("There is no blanks left!")
3213
3214    def list_basic_parameters(
3215        self,
3216    ):
3217        """
3218        Get a list of most relevant parameters and code to modify them
3219        """
3220
3221        parameters = self._read_parameters()
3222        print("BASIC PARAMETERS:")
3223        model_name = parameters["general"]["model_name"]
3224        metric_names = parameters["general"]["metric_functions"]
3225        loss_name = parameters["general"]["loss_function"]
3226        feature_extraction = parameters["general"]["feature_extraction"]
3227        print("Here is a list of current parameters.")
3228        print(
3229            "You can copy this code, change the parameters you want to set and run it to update the project config."
3230        )
3231        print("--------------------------------------------------------")
3232        print("project.update_parameters(")
3233        print("    {")
3234        for group in ["general", "data", "training"]:
3235            print(f'        "{group}": {{')
3236            for key in options.basic_parameters[group]:
3237                if key in parameters[group]:
3238                    print(
3239                        f'            "{key}": {self._process_value(parameters[group][key])}, {self._get_comment(parameters[group].ca.items.get(key))}'
3240                    )
3241            print("        },")
3242        print('        "losses": {')
3243        print(f'            "{loss_name}": {{')
3244        for key in options.basic_parameters["losses"][loss_name]:
3245            if key in parameters["losses"][loss_name]:
3246                print(
3247                    f'                "{key}": {self._process_value(parameters["losses"][loss_name][key])}, {self._get_comment(parameters["losses"][loss_name].ca.items.get(key))}'
3248                )
3249        print("            },")
3250        print("        },")
3251        print('        "metrics": {')
3252        for metric in metric_names:
3253            print(f'            "{metric}": {{')
3254            for key in parameters["metrics"][metric]:
3255                print(
3256                    f'                "{key}": {self._process_value(parameters["metrics"][metric][key])}, {self._get_comment(parameters["metrics"][metric].ca.items.get(key))}'
3257                )
3258            print("            },")
3259        print("        },")
3260        print('        "model": {')
3261        for key in options.basic_parameters["model"][model_name]:
3262            if key in parameters["model"]:
3263                print(
3264                    f'            "{key}": {self._process_value(parameters["model"][key])}, {self._get_comment(parameters["model"].ca.items.get(key))}'
3265                )
3266
3267        print("        },")
3268        print('        "features": {')
3269        for key in options.basic_parameters["features"][feature_extraction]:
3270            if key in parameters["features"]:
3271                print(
3272                    f'            "{key}": {self._process_value(parameters["features"][key])}, {self._get_comment(parameters["features"].ca.items.get(key))}'
3273                )
3274
3275        print("        },")
3276        print('        "augmentations": {')
3277        for key in options.basic_parameters["augmentations"][feature_extraction]:
3278            if key in parameters["augmentations"]:
3279                print(
3280                    f'            "{key}": {self._process_value(parameters["augmentations"][key])}, {self._get_comment(parameters["augmentations"].ca.items.get(key))}'
3281                )
3282        print("        },")
3283        print("    },")
3284        print(")")
3285        print("--------------------------------------------------------")
3286        print("\n")
3287
3288    def _create_record(
3289        self,
3290        episode_name: str,
3291        behaviors_dict: Dict,
3292        load_episode: str = None,
3293        parameters_update: Dict = None,
3294        task: TaskDispatcher = None,
3295        load_epoch: int = None,
3296        load_search: str = None,
3297        load_parameters: list = None,
3298        round_to_binary: list = None,
3299        load_strict: bool = True,
3300        n_seeds: int = 1,
3301    ) -> TaskDispatcher:
3302        """
3303        Create a meta data episode record
3304        """
3305
3306        if episode_name in self._episodes().data.index:
3307            return
3308        if type(n_seeds) is not int or n_seeds < 1:
3309            raise ValueError(
3310                f"The n_seeds parameter has to be an integer larger than 0; got {n_seeds}"
3311            )
3312        if parameters_update is None:
3313            parameters_update = {}
3314        parameters = self._read_parameters()
3315        parameters = self._update(parameters, parameters_update)
3316        if load_search is not None:
3317            parameters = self._update_with_search(
3318                parameters, load_search, load_parameters, round_to_binary
3319            )
3320        parameters = self._fill(
3321            parameters,
3322            episode_name,
3323            load_episode,
3324            load_epoch=load_epoch,
3325            only_load_model=True,
3326            load_strict=load_strict,
3327            continuing=True,
3328        )
3329        self._save_episode(episode_name, parameters, behaviors_dict)
3330        return task
3331
3332    def _save_thresholds(
3333        self,
3334        episode_names: List,
3335        metric_name: str,
3336        parameters: Dict,
3337        thresholds: List,
3338        load_epochs: List,
3339    ):
3340        """
3341        Save optimal decision thresholds in the meta records
3342        """
3343
3344        metric_parameters = parameters["metrics"][metric_name]
3345        self._thresholds().save_thresholds(
3346            episode_names, load_epochs, metric_name, metric_parameters, thresholds
3347        )
3348
3349    def _save_episode(
3350        self,
3351        episode_name: str,
3352        parameters: Dict,
3353        behaviors_dict: Dict,
3354        suppress_validation: bool = False,
3355        training_time: str = None,
3356        norm_stats: Dict = None,
3357    ) -> None:
3358        """
3359        Save an episode in the meta files
3360        """
3361
3362        try:
3363            split_info = self._split_info_from_filename(
3364                parameters["training"]["split_path"]
3365            )
3366            parameters["training"]["partition_method"] = split_info["partition_method"]
3367        except:
3368            pass
3369        if norm_stats is not None:
3370            norm_stats = dict(norm_stats)
3371        parameters["training"]["stats"] = norm_stats
3372        self._episodes().save_episode(
3373            episode_name,
3374            parameters,
3375            behaviors_dict,
3376            suppress_validation=suppress_validation,
3377            training_time=training_time,
3378        )
3379
3380    def _update_episode_results(
3381        self,
3382        episode_name: str,
3383        logs: Tuple,
3384        training_time: str = None,
3385    ) -> None:
3386        """
3387        Save the results of a run in the meta files
3388        """
3389
3390        self._episodes().update_episode_results(episode_name, logs, training_time)
3391
3392    def _save_prediction(
3393        self,
3394        episode_name: str,
3395        parameters: Dict,
3396        behaviors_dict: Dict,
3397        embedding: bool = False,
3398        inference_time: str = None,
3399    ) -> None:
3400        """
3401        Save a prediction in the meta files
3402        """
3403
3404        parameters = self._update(
3405            parameters,
3406            {"meta": {"embedding": embedding, "inference_time": inference_time}},
3407        )
3408        self._predictions().save_episode(episode_name, parameters, behaviors_dict)
3409
3410    def _save_search(
3411        self,
3412        search_name: str,
3413        parameters: Dict,
3414        n_trials: int,
3415        best_params: Dict,
3416        best_value: float,
3417        metric: str,
3418        search_space: Dict,
3419    ) -> None:
3420        """
3421        Save a hyperparameter search in the meta files
3422        """
3423
3424        self._searches().save_search(
3425            search_name,
3426            parameters,
3427            n_trials,
3428            best_params,
3429            best_value,
3430            metric,
3431            search_space,
3432        )
3433
3434    def _save_stores(self, parameters: Dict) -> None:
3435        """
3436        Save a pickled dataset in the meta files
3437        """
3438
3439        name = os.path.basename(parameters["data"]["feature_save_path"])
3440        self._saved_datasets().save_store(name, self._get_data_pars(parameters))
3441        self.create_metadata_backup()
3442
3443    def _remove_stores(self, parameters: Dict, remove_active: bool = False) -> None:
3444        """
3445        Remove the pre-computed features folder
3446        """
3447
3448        name = os.path.basename(parameters["data"]["feature_save_path"])
3449        if remove_active or name not in self._episodes().get_active_datasets():
3450            self.remove_saved_features([name])
3451
3452    def _check_episode_validity(
3453        self, episode_name: str, allow_doublecolon: bool = False, force: bool = False
3454    ) -> None:
3455        """
3456        Check whether the episode name is valid
3457        """
3458
3459        if episode_name.startswith("_"):
3460            raise ValueError(
3461                "Names starting with an underscore are reserved by dlc2action and cannot be used!"
3462            )
3463        elif "." in episode_name:
3464            raise ValueError("Names containing '.' cannot be used!")
3465        if not allow_doublecolon and "::" in episode_name:
3466            raise ValueError(
3467                "Names containing '::' are reserved by dlc2action and cannot be used!"
3468            )
3469        if force:
3470            self.remove_episode(episode_name)
3471        elif not self._episodes().check_name_validity(episode_name):
3472            raise ValueError(
3473                f"The {episode_name} name is already taken! Set force=True to overwrite."
3474            )
3475
3476    def _check_search_validity(self, search_name: str, force: bool = False) -> None:
3477        """
3478        Check whether the search name is valid
3479        """
3480
3481        if search_name.startswith("_"):
3482            raise ValueError(
3483                "Names starting with an underscore are reserved by dlc2action and cannot be used!"
3484            )
3485        elif "." in search_name:
3486            raise ValueError("Names containing '.' cannot be used!")
3487        if force:
3488            self.remove_search(search_name)
3489        elif not self._searches().check_name_validity(search_name):
3490            raise ValueError(f"The {search_name} name is already taken!")
3491
3492    def _check_prediction_validity(
3493        self, prediction_name: str, force: bool = False
3494    ) -> None:
3495        """
3496        Check whether the prediction name is valid
3497        """
3498
3499        if prediction_name.startswith("_"):
3500            raise ValueError(
3501                "Names starting with an underscore are reserved by dlc2action and cannot be used!"
3502            )
3503        elif "." in prediction_name:
3504            raise ValueError("Names containing '.' cannot be used!")
3505        if force:
3506            self.remove_prediction(prediction_name)
3507        elif not self._predictions().check_name_validity(prediction_name):
3508            raise ValueError(f"The {prediction_name} name is already taken!")
3509
3510    def _training_time(self, episode_name: str) -> int:
3511        """
3512        Get the training time of an episode in seconds
3513        """
3514
3515        return self._episode(episode_name).training_time()
3516
3517    def _mask_path(self) -> str:
3518        """
3519        Get the path to the masks folder
3520        """
3521
3522        return os.path.join(self.project_path, "results", "masks")
3523
3524    def _thresholds_path(self) -> str:
3525        """
3526        Get the path to the thresholds meta file
3527        """
3528
3529        return os.path.join(self.project_path, "meta", "thresholds.pickle")
3530
3531    def _episodes_path(self) -> str:
3532        """
3533        Get the path to the episodes meta file
3534        """
3535
3536        return os.path.join(self.project_path, "meta", "episodes.pickle")
3537
3538    def _saved_datasets_path(self) -> str:
3539        """
3540        Get the path to the datasets meta file
3541        """
3542
3543        return os.path.join(self.project_path, "meta", "saved_datasets.pickle")
3544
3545    def _predictions_path(self) -> str:
3546        """
3547        Get the path to the predictions meta file
3548        """
3549
3550        return os.path.join(self.project_path, "meta", "predictions.pickle")
3551
3552    def _dataset_store_path(self, name: str) -> str:
3553        """
3554        Get the path to a specific pickled dataset
3555        """
3556
3557        return os.path.join(self.project_path, "saved_datasets", f"{name}.pickle")
3558
3559    def _searches_path(self) -> str:
3560        """
3561        Get the path to the hyperparameter search meta file
3562        """
3563
3564        return os.path.join(self.project_path, "meta", "searches.pickle")
3565
3566    def _search_path(self, name: str) -> str:
3567        """
3568        Get the default path to the graph folder for a specific hyperparameter search
3569        """
3570
3571        return os.path.join(self.project_path, "results", "searches", name)
3572
3573    def _version_path(self) -> str:
3574        """
3575        Get the path to the version file
3576        """
3577
3578        return os.path.join(self.project_path, "meta", "version.txt")
3579
3580    def _default_split_file(self, split_info: Dict) -> Optional[str]:
3581        """
3582        Generate a path to a split file from split parameters
3583        """
3584
3585        if split_info["partition_method"].startswith("time"):
3586            return None
3587        val_frac = split_info["val_frac"]
3588        test_frac = split_info["test_frac"]
3589        split_name = f'{split_info["partition_method"]}_val{val_frac * 100}%_test{test_frac * 100}%_len{split_info["len_segment"]}_overlap{split_info["overlap"]}'
3590        if not split_info["only_load_annotated"]:
3591            split_name += "_all"
3592        split_name += ".txt"
3593        return os.path.join(self.project_path, "results", "splits", split_name)
3594
3595    def _split_info_from_filename(self, split_name: str) -> Dict:
3596        """
3597        Get split parameters from default path to a split file
3598        """
3599
3600        if split_name is None:
3601            return {}
3602        try:
3603            name = os.path.basename(split_name)[:-4]
3604            split = name.split("_")
3605            if len(split) == 6:
3606                only_load_annotated = False
3607            else:
3608                only_load_annotated = True
3609            len_segment = int(split[3][3:])
3610            overlap = int(split[4][7:])
3611            method, val, test = split[:3]
3612            val = float(val[3:-1]) / 100
3613            test = float(test[4:-1]) / 100
3614            return {
3615                "partition_method": method,
3616                "val_frac": val,
3617                "test_frac": test,
3618                "only_load_annotated": only_load_annotated,
3619                "len_segment": len_segment,
3620                "overlap": overlap,
3621            }
3622        except:
3623            return {"partition_method": "file"}
3624
3625    def _fill(
3626        self,
3627        parameters: Dict,
3628        episode_name: str,
3629        load_experiment: str = None,
3630        load_epoch: int = None,
3631        load_strict: bool = True,
3632        only_load_model: bool = False,
3633        continuing: bool = False,
3634        enforce_split_parameters: bool = False,
3635    ) -> Dict:
3636        """
3637        Update the parameters from the config files with project specific information
3638
3639        Fill in the constant file path parameters and generate a unique log file and a model folder.
3640        Fill in the split file if the same split has been run before in the project and change partition method to
3641        from_file.
3642        Fill in saved data path if a dataset with the same data parameters already exists in the project.
3643        If load_experiment is not None, fill in the checkpoint path as well.
3644        The only_load_model training parameter is defined by the corresponding argument.
3645        If continuing is True, new files are not created and all information is loaded from load_experiment.
3646        If prediction is True, log and model files are not created.
3647        The enforce_split_parameters parameter is used to resolve conflicts
3648        between split file path and split parameters when they arise.
3649        """
3650
3651        pars = deepcopy(parameters)
3652        if episode_name == "_":
3653            self.remove_episode("_")
3654        log = os.path.join(self.project_path, "results", "logs", f"{episode_name}.txt")
3655        model_save_path = os.path.join(
3656            self.project_path, "results", "model", episode_name
3657        )
3658        if not continuing and (os.path.exists(log) or os.path.exists(model_save_path)):
3659            raise ValueError(
3660                f"The {episode_name} episode name is already in use! Set force=True to overwrite."
3661            )
3662        keys = ["val_frac", "test_frac", "partition_method"]
3663        if "len_segment" not in pars["general"] and "len_segment" in pars["data"]:
3664            pars["general"]["len_segment"] = pars["data"]["len_segment"]
3665        if "overlap" not in pars["general"] and "overlap" in pars["data"]:
3666            pars["general"]["overlap"] = pars["data"]["overlap"]
3667        if "len_segment" in pars["data"]:
3668            pars["data"].pop("len_segment")
3669        if "overlap" in pars["data"]:
3670            pars["data"].pop("overlap")
3671        split_info = {k: pars["training"][k] for k in keys}
3672        split_info["only_load_annotated"] = pars["general"]["only_load_annotated"]
3673        split_info["len_segment"] = pars["general"]["len_segment"]
3674        split_info["overlap"] = pars["general"]["overlap"]
3675        pars["training"]["log_file"] = log
3676        if not os.path.exists(model_save_path):
3677            os.mkdir(model_save_path)
3678        pars["training"]["model_save_path"] = model_save_path
3679        if load_experiment is not None:
3680            if load_experiment not in self._episodes().data.index:
3681                raise ValueError(f"The {load_experiment} episode does not exist!")
3682            old_episode = self._episode(load_experiment)
3683            old_file = old_episode.split_file()
3684            old_info = self._split_info_from_filename(old_file)
3685            if len(old_info) == 0:
3686                old_info = old_episode.split_info()
3687            if enforce_split_parameters:
3688                if split_info["partition_method"] != "file":
3689                    pars["training"]["split_path"] = self._default_split_file(
3690                        split_info
3691                    )
3692            else:
3693                equal = True
3694                if old_info["partition_method"] != split_info["partition_method"]:
3695                    equal = False
3696                if old_info["partition_method"] != "file":
3697                    if (
3698                        old_info["val_frac"] != split_info["val_frac"]
3699                        or old_info["test_frac"] != split_info["test_frac"]
3700                    ):
3701                        equal = False
3702                if not continuing and not equal:
3703                    warnings.warn(
3704                        f"The partitioning parameters in the loaded experiment ({old_info}) "
3705                        f"are not equal to the current partitioning parameters ({split_info}). "
3706                        f"The current parameters are replaced."
3707                    )
3708                pars["training"]["split_path"] = old_file
3709            pars["training"]["checkpoint_path"] = old_episode.model_file(load_epoch)
3710            pars["training"]["load_strict"] = load_strict
3711        else:
3712            pars["training"]["checkpoint_path"] = None
3713            if pars["training"]["partition_method"] == "file":
3714                if (
3715                    "split_path" not in pars["training"]
3716                    or pars["training"]["split_path"] is None
3717                ):
3718                    raise ValueError(
3719                        "The partition_method parameter is set to file but the "
3720                        "split_path parameter is not set!"
3721                    )
3722                elif not os.path.exists(pars["training"]["split_path"]):
3723                    raise ValueError(
3724                        f'The {pars["training"]["split_path"]} split file does not exist'
3725                    )
3726            else:
3727                pars["training"]["split_path"] = self._default_split_file(split_info)
3728        pars["training"]["only_load_model"] = only_load_model
3729        pars["data"]["saved_data_path"] = None
3730        pars["data"]["feature_save_path"] = None
3731        pars_data_copy = self._get_data_pars(pars)
3732        saved_data_name = self._saved_datasets().find_name(pars_data_copy)
3733        if saved_data_name is not None:
3734            pars["data"]["saved_data_path"] = self._dataset_store_path(saved_data_name)
3735            pars["data"]["feature_save_path"] = self._dataset_store_path(
3736                saved_data_name
3737            ).split(".")[0]
3738        else:
3739            dataset_path = self._dataset_store_path(episode_name)
3740            if os.path.exists(dataset_path):
3741                name, ext = dataset_path.split(".")
3742                i = 0
3743                while os.path.exists(f"{name}_{i}.{ext}"):
3744                    i += 1
3745                dataset_path = f"{name}_{i}.{ext}"
3746            pars["data"]["saved_data_path"] = dataset_path
3747            pars["data"]["feature_save_path"] = dataset_path.split(".")[0]
3748        split_split = pars["training"]["partition_method"].split(":")
3749        random = True
3750        for partition_method in options.partition_methods["fixed"]:
3751            method_split = partition_method.split(":")
3752            if len(split_split) != len(method_split):
3753                continue
3754            equal = True
3755            for x, y in zip(split_split, method_split):
3756                if y.startswith("{"):
3757                    continue
3758                if x != y:
3759                    equal = False
3760                    break
3761            if equal:
3762                random = False
3763                break
3764        if random and os.path.exists(pars["training"]["split_path"]):
3765            pars["training"]["partition_method"] = "file"
3766        pars["general"]["save_dataset"] = True
3767        return pars
3768
3769    def _get_data_pars(self, pars: Dict) -> Dict:
3770        """
3771        Get a complete description of the data from a general parameters dictionary
3772        """
3773
3774        pars_data_copy = deepcopy(pars["data"])
3775        for par in [
3776            "only_load_annotated",
3777            "exclusive",
3778            "feature_extraction",
3779            "ignored_clips",
3780            "len_segment",
3781            "overlap",
3782        ]:
3783            pars_data_copy[par] = pars["general"].get(par, None)
3784        pars_data_copy.update(pars["features"])
3785        return pars_data_copy
3786
3787    def count_classes(
3788        self,
3789        load_episode: str = None,
3790        parameters_update: Dict = None,
3791        remove_saved_features: bool = False,
3792        bouts: bool = True,
3793    ) -> Dict:
3794        """
3795        Get a dictionary of class counts in different modes
3796
3797        Parameters
3798        ----------
3799        load_episode : str, optional
3800            the episode settings to load
3801        parameters_update : dict, optional
3802            a dictionary of parameter updates (only for "data" and "general" categories)
3803        remove_saved_features : bool, default False
3804            if `True`, the dataset that is used for computation is then deleted
3805        bouts : bool, default False
3806            if `True`, instead of frame counts segment counts are returned
3807
3808        Returns
3809        -------
3810        class_counts : dict
3811            a dictionary where first-level keys are "train", "val" and "test", second-level keys are
3812            class names and values are class counts (in frames)
3813        """
3814
3815        if load_episode is None:
3816            task, parameters = self._make_task_training(
3817                episode_name="_", parameters_update=parameters_update, throwaway=True
3818            )
3819        else:
3820            task, parameters, _ = self._make_task_prediction(
3821                "_",
3822                load_episode=load_episode,
3823                parameters_update=parameters_update,
3824            )
3825        class_counts = task.count_classes(bouts=bouts)
3826        behaviors = task.behaviors_dict()
3827        class_counts = {
3828            kk: {behaviors.get(k, "unknown"): v for k, v in vv.items()}
3829            for kk, vv in class_counts.items()
3830        }
3831        if remove_saved_features:
3832            self._remove_stores(parameters)
3833        return class_counts
3834
3835    def plot_class_distribution(
3836        self,
3837        parameters_update: Dict = None,
3838        frame_cutoff: int = 1,
3839        bout_cutoff: int = 1,
3840        print_full: bool = False,
3841        remove_saved_features: bool = False,
3842    ) -> None:
3843        """
3844        Make a class distribution plot
3845
3846        You can either specify the parameters, choose an existing dataset or do neither (in that case a dataset
3847        is created or laoded for the computation with the default parameters).
3848
3849        Parameters
3850        ----------
3851        parameters_update : dict, optional
3852            a dictionary of parameter updates (only for "data" and "general" categories)
3853        remove_saved_features : bool, default False
3854            if `True`, the dataset that is used for computation is then deleted
3855        """
3856
3857        task, parameters = self._make_task_training(
3858            episode_name="_", parameters_update=parameters_update, throwaway=True
3859        )
3860        cutoff = {True: bout_cutoff, False: frame_cutoff}
3861        for bouts in [True, False]:
3862            class_counts = task.count_classes(bouts=bouts)
3863            if print_full:
3864                print("Bouts:" if bouts else "Frames:")
3865                for k, v in class_counts.items():
3866                    if sum(v.values()) != 0:
3867                        print(f"  {k}:")
3868                        values, keys = zip(
3869                            *[
3870                                x
3871                                for x in sorted(zip(v.values(), v.keys()), reverse=True)
3872                                if x[-1] != -100
3873                            ]
3874                        )
3875                        for kk, vv in zip(keys, values):
3876                            print(f"    {task.behaviors_dict()[kk]}: {vv}")
3877            class_counts = {
3878                kk: {k: v for k, v in vv.items() if v >= cutoff[bouts]}
3879                for kk, vv in class_counts.items()
3880            }
3881            for key, d in class_counts.items():
3882                if sum(d.values()) != 0:
3883                    values, keys = zip(
3884                        *[x for x in sorted(zip(d.values(), d.keys())) if x[-1] != -100]
3885                    )
3886                    keys = [task.behaviors_dict()[x] for x in keys]
3887                    plt.bar(keys, values)
3888                    plt.title(key)
3889                    plt.xticks(rotation=45, ha="right")
3890                    if bouts:
3891                        plt.ylabel("bouts")
3892                    else:
3893                        plt.ylabel("frames")
3894                    plt.tight_layout()
3895                    plt.show()
3896        if remove_saved_features:
3897            self._remove_stores(parameters)
3898
3899    def _generate_mask(
3900        self,
3901        mask_name: str,
3902        perc_annotated: float = 0.1,
3903        parameters_update: Dict = None,
3904        remove_saved_features: bool = False,
3905    ) -> None:
3906        """
3907        Generate a real_lens for active learning simulation
3908
3909        Parameters
3910        ----------
3911        mask_name : str
3912            the name of the real_lens
3913        """
3914
3915        print(f"GENERATING {mask_name}")
3916        task, parameters = self._make_task_training(
3917            f"_{mask_name}", parameters_update=parameters_update, throwaway=True
3918        )
3919        val_intervals, val_ids = task.dataset("val").get_intervals()  # 1
3920        unannotated_intervals = task.dataset("train").get_unannotated_intervals()  # 2
3921        unannotated_intervals = task.dataset("val").get_unannotated_intervals(
3922            first_intervals=unannotated_intervals
3923        )
3924        ids = task.dataset("train").get_ids()
3925        mask = {video_id: {} for video_id in ids}
3926        total_all = 0
3927        total_masked = 0
3928        for video_id, clip_ids in ids.items():
3929            for clip_id in clip_ids:
3930                frames = np.ones(task.dataset("train").get_len(video_id, clip_id))
3931                if clip_id in val_intervals[video_id]:
3932                    for start, end in val_intervals[video_id][clip_id]:
3933                        frames[start:end] = 0
3934                if clip_id in unannotated_intervals[video_id]:
3935                    for start, end in unannotated_intervals[video_id][clip_id]:
3936                        frames[start:end] = 0
3937                annotated = np.where(frames)[0]
3938                total_all += len(annotated)
3939                masked = annotated[-int(len(annotated) * (1 - perc_annotated)) :]
3940                total_masked += len(masked)
3941                mask[video_id][clip_id] = self._get_intervals(masked)
3942        file = {
3943            "masked": mask,
3944            "val_intervals": val_intervals,
3945            "val_ids": val_ids,
3946            "unannotated": unannotated_intervals,
3947        }
3948        self._save_mask(file, mask_name)
3949        if remove_saved_features:
3950            self._remove_stores(parameters)
3951        print("\n")
3952        # print(f'Unmasked: {sum([(vv == 0).sum() for v in real_lens.values() for vv in v.values()])} frames')
3953
3954    def _get_intervals(self, frame_indices: np.ndarray):
3955        """
3956        Get a list of intervals from a list of frame indices
3957
3958        Example: `[0, 1, 2, 5, 6, 8] -> [[0, 3], [5, 7], [8, 9]]`.
3959
3960        Parameters
3961        ----------
3962        frame_indices : np.ndarray
3963            a list of frame indices
3964
3965        Returns
3966        -------
3967        intervals : list
3968            a list of interval boundaries
3969        """
3970
3971        masked_intervals = []
3972        if len(frame_indices) > 0:
3973            breaks = np.where(np.diff(frame_indices) != 1)[0]
3974            start = frame_indices[0]
3975            for k in breaks:
3976                masked_intervals.append([start, frame_indices[k] + 1])
3977                start = frame_indices[k + 1]
3978            masked_intervals.append([start, frame_indices[-1] + 1])
3979        return masked_intervals
3980
3981    def _update_mask_with_uncertainty(
3982        self,
3983        mask_name: str,
3984        episode_name: Union[str, None],
3985        classes: List,
3986        load_epoch: int = None,
3987        n_frames: int = 10000,
3988        method: str = "least_confidence",
3989        min_length: int = 30,
3990        augment_n: int = 0,
3991        parameters_update: Dict = None,
3992    ):
3993        """
3994        Update real_lens with frame-wise uncertainty scores for active learning
3995
3996        Parameters
3997        ----------
3998        mask_name : str
3999            the name of the real_lens
4000        episode_name : str
4001            the name of the episode to load
4002        classes : list
4003            a list of class names or indices; their uncertainty scores will be computed separately and stacked
4004        n_frames : int, default 10000
4005            the number of frames to "annotate"
4006        method : {"least_confidence", "entropy"}
4007            the method used to calculate the scores from the probability predictions (`"least_confidence"`: `1 - p_i` if
4008            `p_i > 0.5` or `p_i` if `p_i < 0.5`; `"entropy"`: `- p_i * log(p_i) - (1 - p_i) * log(1 - p_i)`)
4009        min_length : int
4010            the minimum length (in frames) of the annotated intervals
4011        augment_n : int, default 0
4012            the number of augmentations to average over
4013        parameters_update : dict, optional
4014            the dictionary used to update the parameters from the config
4015
4016        Returns
4017        -------
4018        score_dicts : dict
4019            a nested dictionary where first level keys are video ids, second level keys are clip ids and values
4020            are score tensors
4021        """
4022
4023        print(f"UPDATING {mask_name}")
4024        task, parameters, _ = self._make_task_prediction(
4025            prediction_name=mask_name,
4026            load_episode=episode_name,
4027            parameters_update=parameters_update,
4028            load_epoch=load_epoch,
4029            mode="train",
4030        )
4031        score_tensors = task.generate_uncertainty_score(classes, augment_n, method)
4032        self._update_mask(task, mask_name, score_tensors, n_frames, min_length)
4033        print("\n")
4034
4035    def _update_mask_with_BALD(
4036        self,
4037        mask_name: str,
4038        episode_name: str,
4039        classes: List,
4040        load_epoch: int = None,
4041        augment_n: int = 0,
4042        n_frames: int = 10000,
4043        num_models: int = 10,
4044        kernel_size: int = 11,
4045        min_length: int = 30,
4046        parameters_update: Dict = None,
4047    ):
4048        """
4049        Update real_lens with frame-wise Bayesian Active Learning by Disagreement scores for active learning
4050
4051        Parameters
4052        ----------
4053        mask_name : str
4054            the name of the real_lens
4055        episode_name : str
4056            the name of the episode to load
4057        classes : list
4058            a list of class names or indices; their uncertainty scores will be computed separately and stacked
4059        augment_n : int, default 0
4060            the number of augmentations to average over
4061        n_frames : int, default 10000
4062            the number of frames to "annotate"
4063        num_models : int, default 10
4064            the number of dropout masks to apply
4065        kernel_size : int, default 11
4066            the size of the smoothing gaussian kernel
4067        min_length : int
4068            the minimum length (in frames) of the annotated intervals
4069        parameters_update : dict, optional
4070            the dictionary used to update the parameters from the config
4071
4072        Returns
4073        -------
4074        score_dicts : dict
4075            a nested dictionary where first level keys are video ids, second level keys are clip ids and values
4076            are score tensors
4077        """
4078
4079        print(f"UPDATING {mask_name}")
4080        task, parameters, mode = self._make_task_prediction(
4081            mask_name,
4082            load_episode=episode_name,
4083            parameters_update=parameters_update,
4084            load_epoch=load_epoch,
4085        )
4086        score_tensors = task.generate_bald_score(
4087            classes, augment_n, num_models, kernel_size
4088        )
4089        self._update_mask(task, mask_name, score_tensors, n_frames, min_length)
4090        print("\n")
4091
    def _suggest_intervals(
        self,
        dataset: BehaviorDataset,
        score_tensors: Dict,
        n_frames: int,
        min_length: int,
    ) -> Dict:
        """
        Suggest intervals with highest score of total length `n_frames`

        Frames that are already annotated are excluded from the suggestions. Note that
        `score_tensors` is modified in place: the scores of annotated frames are set to -10.

        Parameters
        ----------
        dataset : BehaviorDataset
            the dataset
        score_tensors : dict
            a nested dictionary where first level keys are video ids, second level keys are clip ids
            and values are score tensors indexed as `[score_index, frame]`
        n_frames : int
            the number of frames to "annotate"
        min_length : int
            the minimum interval length (passed to `dataset.find_valleys` as `min_frames`)

        Returns
        -------
        active_learning_intervals : Dict
            active learning dictionary with suggested intervals: keys are video ids and values are
            lists of `[start, end, clip_id]`
        """

        video_intervals, _ = dataset.get_intervals()
        # taken[video_id][clip_id] is a frame-wise tensor: 1 for selected (or already annotated) frames
        taken = {
            video_id: defaultdict(lambda: {}) for video_id in video_intervals.keys()
        }
        annotated = dataset.get_annotated_intervals()
        for video_id in video_intervals:
            for clip_id in video_intervals[video_id]:
                taken[video_id][clip_id] = torch.zeros(
                    dataset.get_len(video_id, clip_id)
                )
                if video_id in annotated and clip_id in annotated[video_id]:
                    for start, end in annotated[video_id][clip_id]:
                        # annotated frames get a negative score (in place) and count as taken for now
                        score_tensors[video_id][clip_id][:, start:end] = -10
                        taken[video_id][clip_id][int(start) : int(end)] = 1
        # from here on n_frames is the target total: already annotated frames + requested frames
        n_frames = (
            sum([(vv == 1).sum() for v in taken.values() for vv in v.values()])
            + n_frames
        )
        factor = 1
        # starting threshold: mean over videos of the mean over clips of the mean positive score
        threshold_start = float(
            torch.mean(
                torch.tensor(
                    [
                        torch.mean(
                            torch.tensor([torch.mean(y[y > 0]) for y in x.values()])
                        )
                        for x in score_tensors.values()
                    ]
                )
            )
        )
        # keep lowering the threshold until enough frames are taken
        while (
            sum([(vv == 1).sum() for v in taken.values() for vv in v.values()])
            < n_frames
        ):
            threshold = threshold_start * factor
            intervals = []
            interval_scores = []
            # the number of score rows is read from the first clip of the first video
            key1 = list(score_tensors.keys())[0]
            key2 = list(score_tensors[key1].keys())[0]
            num_scores = score_tensors[key1][key2].shape[0]
            for i in range(num_scores):
                v_dict = dataset.find_valleys(
                    predicted=score_tensors,
                    threshold=threshold,
                    min_frames=min_length,
                    main_class=i,
                    low=False,
                )
                for v_id, interval_list in v_dict.items():
                    # rows become [start, end, clip_id, video_id]
                    intervals += [x + [v_id] for x in interval_list]
                    interval_scores += [
                        float(torch.mean(score_tensors[v_id][clip_id][i, start:end]))
                        for start, end, clip_id in interval_list
                    ]
            # order the candidate intervals by descending mean score
            # NOTE(review): np.array on mixed-type rows presumably coerces everything to strings,
            # hence the int() casts below - confirm
            intervals = np.array(intervals)[np.argsort(interval_scores)[::-1]]
            i = 0
            while sum(
                [(vv == 1).sum() for v in taken.values() for vv in v.values()]
            ) < n_frames and i < len(intervals):
                start, end, clip_id, video_id = intervals[i]
                i += 1
                taken[video_id][clip_id][int(start) : int(end)] = 1
            factor *= 0.9
            # give up (with a warning) once the threshold has dropped below 5% of its starting value
            if factor < 0.05:
                warnings.warn(f"Could not find enough frames!")
                break
        active_learning_intervals = {video_id: [] for video_id in video_intervals}
        for video_id in taken:
            for clip_id in taken[video_id]:
                # already annotated frames are not suggestions: clear them before exporting
                if video_id in annotated and clip_id in annotated[video_id]:
                    for start, end in annotated[video_id][clip_id]:
                        taken[video_id][clip_id][int(start) : int(end)] = 0
                if (taken[video_id][clip_id] == 1).sum() == 0:
                    continue
                indices = np.where(taken[video_id][clip_id].numpy())[0]
                boundaries = self._get_intervals(indices)
                active_learning_intervals[video_id] += [
                    [start, end, clip_id] for start, end in boundaries
                ]
        return active_learning_intervals
4199
    def _update_mask(
        self,
        task: TaskDispatcher,
        mask_name: str,
        score_tensors: Dict,
        n_frames: int,
        min_length: int,
    ) -> None:
        """
        Update the mask with intervals with the highest score of total length `n_frames`

        The mask is loaded with `_load_mask`, the selected intervals are unmasked and the result is
        saved with `_save_mask`. A summary line is appended to `results/{mask_name}.txt` in the
        project folder and printed. Note that `score_tensors` is modified in place: the scores of
        unannotated and validation frames are set to -10.

        Parameters
        ----------
        task : TaskDispatcher
            the task; its "train" dataset is used for clip lengths and interval search
        mask_name : str
            the name of the mask
        score_tensors : dict
            a nested dictionary where first level keys are video ids, second level keys are clip ids
            and values are score tensors indexed as `[score_index, frame]`
        n_frames : int
            the number of frames to "annotate"
        min_length : int
            the minimum length of the annotated intervals
        """

        mask = self._load_mask(mask_name)
        video_intervals, _ = task.dataset("train").get_intervals()
        # masked[video_id][clip_id] is a frame-wise tensor: 1 for unavailable frames, 0 for unmasked ones
        masked = {
            video_id: defaultdict(lambda: {}) for video_id in video_intervals.keys()
        }
        total_masked = 0
        total_all = 0
        for video_id in video_intervals:
            for clip_id in video_intervals[video_id]:
                masked[video_id][clip_id] = torch.zeros(
                    task.dataset("train").get_len(video_id, clip_id)
                )
                # unannotated and validation frames can never be selected:
                # give them a negative score (in place) and mark them as unavailable
                if (
                    video_id in mask["unannotated"]
                    and clip_id in mask["unannotated"][video_id]
                ):
                    for start, end in mask["unannotated"][video_id][clip_id]:
                        score_tensors[video_id][clip_id][:, start:end] = -10
                        masked[video_id][clip_id][int(start) : int(end)] = 1
                if (
                    video_id in mask["val_intervals"]
                    and clip_id in mask["val_intervals"][video_id]
                ):
                    for start, end in mask["val_intervals"][video_id][clip_id]:
                        score_tensors[video_id][clip_id][:, start:end] = -10
                        masked[video_id][clip_id][int(start) : int(end)] = 1
                # frames that are neither unannotated nor validation (counted before applying the mask)
                total_all += torch.sum(masked[video_id][clip_id] == 0)
                if video_id in mask["masked"] and clip_id in mask["masked"][video_id]:
                    for start, end in mask["masked"][video_id][clip_id]:
                        masked[video_id][clip_id][int(start) : int(end)] = 1
                        total_masked += end - start
        old_n_frames = sum(
            [(vv == 0).sum() for v in masked.values() for vv in v.values()]
        )
        # from here on n_frames is the target total of unmasked frames
        n_frames = old_n_frames + n_frames
        factor = 1
        # keep lowering the threshold until enough frames are unmasked
        while (
            sum([(vv == 0).sum() for v in masked.values() for vv in v.values()])
            < n_frames
        ):
            # base threshold: mean over videos of the mean over clips of the mean positive score
            threshold = float(
                torch.mean(
                    torch.tensor(
                        [
                            torch.mean(
                                torch.tensor([torch.mean(y[y > 0]) for y in x.values()])
                            )
                            for x in score_tensors.values()
                        ]
                    )
                )
            )
            threshold = threshold * factor
            intervals = []
            interval_scores = []
            # the number of score rows is read from the first clip of the first video
            key1 = list(score_tensors.keys())[0]
            key2 = list(score_tensors[key1].keys())[0]
            num_scores = score_tensors[key1][key2].shape[0]
            for i in range(num_scores):
                v_dict = task.dataset("train").find_valleys(
                    predicted=score_tensors,
                    threshold=threshold,
                    min_frames=min_length,
                    main_class=i,
                    low=False,
                )
                for v_id, interval_list in v_dict.items():
                    # rows become [start, end, clip_id, video_id]
                    intervals += [x + [v_id] for x in interval_list]
                    interval_scores += [
                        float(torch.mean(score_tensors[v_id][clip_id][i, start:end]))
                        for start, end, clip_id in interval_list
                    ]
            # order the candidate intervals by descending mean score
            # NOTE(review): np.array on mixed-type rows presumably coerces everything to strings,
            # hence the int() casts below - confirm
            intervals = np.array(intervals)[np.argsort(interval_scores)[::-1]]
            i = 0
            while sum(
                [(vv == 0).sum() for v in masked.values() for vv in v.values()]
            ) < n_frames and i < len(intervals):
                start, end, clip_id, video_id = intervals[i]
                i += 1
                masked[video_id][clip_id][int(start) : int(end)] = 0
            factor *= 0.9
            # give up (with a warning) once the threshold has dropped below 5% of its base value
            if factor < 0.05:
                warnings.warn(f"Could not find enough frames!")
                break
        mask["masked"] = {video_id: {} for video_id in video_intervals}
        total_masked_new = 0
        for video_id in masked:
            for clip_id in masked[video_id]:
                # unannotated and validation frames are stored under their own keys,
                # so they are cleared here before the masked intervals are extracted
                if (
                    video_id in mask["unannotated"]
                    and clip_id in mask["unannotated"][video_id]
                ):
                    for start, end in mask["unannotated"][video_id][clip_id]:
                        masked[video_id][clip_id][int(start) : int(end)] = 0
                if (
                    video_id in mask["val_intervals"]
                    and clip_id in mask["val_intervals"][video_id]
                ):
                    for start, end in mask["val_intervals"][video_id][clip_id]:
                        masked[video_id][clip_id][int(start) : int(end)] = 0
                indices = np.where(masked[video_id][clip_id].numpy())[0]
                mask["masked"][video_id][clip_id] = self._get_intervals(indices)
        for video_id in mask["masked"]:
            for clip_id in mask["masked"][video_id]:
                for start, end in mask["masked"][video_id][clip_id]:
                    total_masked_new += end - start
        self._save_mask(mask, mask_name)
        # the summary is appended, so the file accumulates one line per update
        with open(
            os.path.join(self.project_path, "results", f"{mask_name}.txt"), "a"
        ) as f:
            f.write(f"from {total_masked} to {total_masked_new} / {total_all}" + "\n")
        print(f"Unmasked from {total_masked} to {total_masked_new} / {total_all}")
4336
4337    def plot_confusion_matrix(
4338        self,
4339        episode_name: str,
4340        load_epoch: int = None,
4341        parameters_update: Dict = None,
4342        type: str = "recall",
4343        mode: str = "val",
4344        remove_saved_features: bool = False,
4345    ) -> Tuple[ndarray, Iterable]:
4346        """
4347        Make a confusion matrix plot and return the data
4348
4349        If the annotation is non-exclusive, only false positive labels are considered.
4350
4351        Parameters
4352        ----------
4353        episode_name : str
4354            the name of the episode to load
4355        load_epoch : int, optional
4356            the index of the epoch to load (by default the last one is loaded)
4357        parameters_update : dict, optional
4358            a dictionary of parameter updates (only for "data" and "general" categories)
4359        mode : {'val', 'all', 'test', 'train'}
4360            the subset of the data to make the prediction for (forced to 'all' if data_path is not None)
4361        type : {"recall", "precision"}
4362            for datasets with non-exclusive annotation, if `type` is `"recall"`, only false positives are taken
4363            into account, and if `type` is `"precision"`, only false negatives
4364        remove_saved_features : bool, default False
4365            if `True`, the dataset that is used for computation is then deleted
4366
4367        Returns
4368        -------
4369        confusion_matrix : np.ndarray
4370            a confusion matrix of shape `(#classes, #classes)` where `A[i, j] = F_ij/N_i`, `F_ij` is the number of
4371            frames that have the i-th label in the ground truth and a false positive j-th label in the prediction,
4372            `N_i` is the number of frames that have the i-th label in the ground truth
4373        classes : list
4374            a list of labels
4375        """
4376
4377        task, parameters, mode = self._make_task_prediction(
4378            "_",
4379            load_episode=episode_name,
4380            load_epoch=load_epoch,
4381            parameters_update=parameters_update,
4382            mode=mode,
4383        )
4384        dataset = task.dataset(mode)
4385        prediction = task.predict(dataset, raw_output=True)
4386        confusion_matrix, classes, type = dataset.get_confusion_matrix(prediction, type)
4387        if remove_saved_features:
4388            self._remove_stores(parameters)
4389        fig, ax = plt.subplots(figsize=(len(classes), len(classes)))
4390        ax.imshow(confusion_matrix)
4391        # Show all ticks and label them with the respective list entries
4392        ax.set_xticks(np.arange(len(classes)))
4393        ax.set_xticklabels(classes)
4394        ax.set_yticks(np.arange(len(classes)))
4395        ax.set_yticklabels(classes)
4396        # Rotate the tick labels and set their alignment.
4397        plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
4398        # Loop over data dimensions and create text annotations.
4399        for i in range(len(classes)):
4400            for j in range(len(classes)):
4401                ax.text(
4402                    j,
4403                    i,
4404                    np.round(confusion_matrix[i, j], 2),
4405                    ha="center",
4406                    va="center",
4407                    color="w",
4408                )
4409        if type is not None:
4410            ax.set_title(f"{type} {episode_name}")
4411        else:
4412            ax.set_title(episode_name)
4413        fig.tight_layout()
4414        plt.show()
4415        return confusion_matrix, classes
4416
4417    def plot_predictions(
4418        self,
4419        episode_name: str,
4420        load_epoch: int = None,
4421        parameters_update: Dict = None,
4422        add_legend: bool = True,
4423        ground_truth: bool = True,
4424        colormap: str = "viridis",
4425        hide_axes: bool = False,
4426        min_classes: int = 1,
4427        width: float = 10,
4428        whole_video: bool = False,
4429        transparent: bool = False,
4430        drop_classes: Set = None,
4431        search_classes: Set = None,
4432        num_plots: int = 1,
4433        remove_saved_features: bool = False,
4434        smooth_interval_prediction: int = 0,
4435        data_path: str = None,
4436        file_paths: Set = None,
4437        mode: str = "val",
4438        behavior_name: str = None,
4439    ) -> None:
4440        """
4441        Visualize random predictions
4442
4443        Parameters
4444        ----------
4445        episode_name : str
4446            the name of the episode to load
4447        load_epoch : int, optional
4448            the epoch to load (by default last)
4449        parameters_update : dict, optional
4450            parameter update dictionary
4451        add_legend : bool, default True
4452            if True, legend will be added to the plot
4453        ground_truth : bool, default True
4454            if True, ground truth will be added to the plot
4455        colormap : str, default 'Accent'
4456            the `matplotlib` colormap to use
4457        hide_axes : bool, default True
4458            if `True`, the axes will be hidden on the plot
4459        min_classes : int, default 1
4460            the minimum number of classes in a displayed interval
4461        width : float, default 10
4462            the width of the plot
4463        whole_video : bool, default False
4464            if `True`, whole videos are plotted instead of segments
4465        transparent : bool, default False
4466            if `True`, the background on the plot is transparent
4467        drop_classes : set, optional
4468            a set of class names to not be displayed
4469        search_classes : set, optional
4470            if given, only intervals where at least one of the classes is in ground truth will be shown
4471        num_plots : int, default 1
4472            the number of plots to make
4473        remove_saved_features : bool, default False
4474            if `True`, the dataset will be deleted after computation
4475        smooth_interval_prediction : int, default 0
4476            if >0, predictions shorter than this number of frames are removed (filled with prediction for the previous frame)
4477        data_path : str, optional
4478            the data path to run the prediction for
4479        mode : {'all', 'test', 'val', 'train'}
4480            the subset of the data to make the prediction for (forced to 'all' if data_path is not None)
4481        file_paths : set, optional
4482            a set of string file paths (data with all prefixes + feature files, in any order) to run the prediction
4483            for
4484        behavior_name : str, optional
4485            for non-exclusive classificaton datasets, choose which behavior to visualize (by default first in list)
4486        """
4487
4488        other_path = os.path.join(self.project_path, "results", "other")
4489        task, parameters, mode = self._make_task_prediction(
4490            "_",
4491            load_episode=episode_name,
4492            parameters_update=parameters_update,
4493            load_epoch=load_epoch,
4494            data_path=data_path,
4495            file_paths=file_paths,
4496            mode=mode,
4497        )
4498        if not os.path.exists(other_path):
4499            os.mkdir(other_path)
4500        for i in range(num_plots):
4501            task.visualize_results(
4502                save_path=os.path.join(
4503                    other_path, f"{episode_name}_prediction_{i}.jpg"
4504                ),
4505                add_legend=add_legend,
4506                ground_truth=ground_truth,
4507                colormap=colormap,
4508                hide_axes=hide_axes,
4509                min_classes=min_classes,
4510                whole_video=whole_video,
4511                transparent=transparent,
4512                dataset=mode,
4513                drop_classes=drop_classes,
4514                search_classes=search_classes,
4515                width=width,
4516                smooth_interval_prediction=smooth_interval_prediction,
4517                behavior_name=behavior_name,
4518            )
4519        if remove_saved_features:
4520            self._remove_stores(parameters)
4521
4522    def create_metadata_backup(self) -> None:
4523        """
4524        Create a copy of the meta files
4525        """
4526
4527        meta_copy_path = os.path.join(self.project_path, "meta", "backup")
4528        meta_path = os.path.join(self.project_path, "meta")
4529        if os.path.exists(meta_copy_path):
4530            shutil.rmtree(meta_copy_path)
4531        os.mkdir(meta_copy_path)
4532        for file in os.listdir(meta_path):
4533            if file == "backup":
4534                continue
4535            shutil.copy(
4536                os.path.join(meta_path, file), os.path.join(meta_copy_path, file)
4537            )
4538
4539    def load_metadata_backup(self) -> None:
4540        """
4541        Load from previously created meta data backup (in case of corruption)
4542        """
4543
4544        meta_copy_path = os.path.join(self.project_path, "meta", "backup")
4545        meta_path = os.path.join(self.project_path, "meta")
4546        for file in os.listdir(meta_copy_path):
4547            shutil.copy(
4548                os.path.join(meta_copy_path, file), os.path.join(meta_path, file)
4549            )
4550
4551    def get_behavior_dictionary(self, episode_name: str) -> Dict:
4552        """
4553        Get the behavior dictionary for an episode
4554
4555        Parameters
4556        ----------
4557        episode_name : str
4558            the name of the episode
4559
4560        Returns
4561        -------
4562        behaviors_dictionary : dict
4563            a dictionary where keys are label indices and values are label names
4564        """
4565
4566        run = self._episodes().get_runs(episode_name)[0]
4567        return self._episode(run).get_behaviors_dict()
4568
4569    def import_episodes(
4570        self,
4571        episodes_directory: str,
4572        name_map: Dict = None,
4573        repeat_policy: str = "error",
4574    ) -> None:
4575        """
4576        Import episodes exported with `Project.export_episodes`
4577
4578        Parameters
4579        ----------
4580        episodes_directory : str
4581            the path to the exported episodes directory
4582        name_map : dict
4583            a name change dictionary for the episodes: keys are old names, values are new names
4584        """
4585
4586        if name_map is None:
4587            name_map = {}
4588        episodes = pd.read_pickle(os.path.join(episodes_directory, "episodes.pickle"))
4589        to_remove = []
4590        import_string = "Imported episodes: "
4591        for episode_name in episodes.index:
4592            if episode_name in name_map:
4593                import_string += f"{episode_name} "
4594                episode_name = name_map[episode_name]
4595                import_string += f"({episode_name}), "
4596            else:
4597                import_string += f"{episode_name}, "
4598            try:
4599                self._check_episode_validity(episode_name, allow_doublecolon=True)
4600            except ValueError as e:
4601                if str(e).endswith("is already taken!"):
4602                    if repeat_policy == "skip":
4603                        to_remove.append(episode_name)
4604                    elif repeat_policy == "force":
4605                        self.remove_episode(episode_name)
4606                    elif repeat_policy == "error":
4607                        raise ValueError(
4608                            f"The {episode_name} episode name is already taken; please use the name_map parameter to rename it"
4609                        )
4610                    else:
4611                        raise ValueError(
4612                            f"The {repeat_policy} repeat policy is not recognized; please choose from ['skip', 'force' ans 'error']"
4613                        )
4614        episodes = episodes.drop(index=to_remove)
4615        self._episodes().update(
4616            episodes,
4617            name_map=name_map,
4618            force=(repeat_policy == "force"),
4619            data_path=self.data_path,
4620            annotation_path=self.annotation_path,
4621        )
4622        for episode_name in episodes.index:
4623            if episode_name in name_map:
4624                new_episode_name = name_map[episode_name]
4625            else:
4626                new_episode_name = episode_name
4627            model_dir = os.path.join(
4628                self.project_path, "results", "model", new_episode_name
4629            )
4630            old_model_dir = os.path.join(episodes_directory, "model", episode_name)
4631            if os.path.exists(model_dir):
4632                shutil.rmtree(model_dir)
4633            os.mkdir(model_dir)
4634            for file in os.listdir(old_model_dir):
4635                shutil.copyfile(
4636                    os.path.join(old_model_dir, file), os.path.join(model_dir, file)
4637                )
4638            log_file = os.path.join(
4639                self.project_path, "results", "logs", f"{new_episode_name}.txt"
4640            )
4641            old_log_file = os.path.join(
4642                episodes_directory, "logs", f"{episode_name}.txt"
4643            )
4644            shutil.copyfile(old_log_file, log_file)
4645        print(import_string)
4646        print("\n")
4647
4648    def export_episodes(
4649        self, episode_names: List, output_directory: str, name: str = None
4650    ) -> None:
4651        """
4652        Save selected episodes as a file that can be imported into another project with `Project.import_episodes`
4653
4654        Parameters
4655        ----------
4656        episode_names : list
4657            a list of string episode names
4658        output_directory : str
4659            the path to the directory where the episodes will be saved
4660        name : str, optional
4661            the name of the episodes directory (by default `exported_episodes`)
4662        """
4663
4664        if name is None:
4665            name = "exported_episodes"
4666        if os.path.exists(
4667            os.path.join(output_directory, name + ".zip")
4668        ) or os.path.exists(os.path.join(output_directory, name)):
4669            i = 1
4670            while os.path.exists(
4671                os.path.join(output_directory, name + f"_{i}.zip")
4672            ) or os.path.exists(os.path.join(output_directory, name + f"_{i}")):
4673                i += 1
4674            name = name + f"_{i}"
4675        dest_dir = os.path.join(output_directory, name)
4676        os.mkdir(dest_dir)
4677        os.mkdir(os.path.join(dest_dir, "model"))
4678        os.mkdir(os.path.join(dest_dir, "logs"))
4679        runs = []
4680        for episode in episode_names:
4681            runs += self._episodes().get_runs(episode)
4682        for run in runs:
4683            shutil.copytree(
4684                os.path.join(self.project_path, "results", "model", run),
4685                os.path.join(dest_dir, "model", run),
4686            )
4687            shutil.copyfile(
4688                os.path.join(self.project_path, "results", "logs", f"{run}.txt"),
4689                os.path.join(dest_dir, "logs", f"{run}.txt"),
4690            )
4691        data = self._episodes().get_subset(runs)
4692        data.to_pickle(os.path.join(dest_dir, "episodes.pickle"))
4693
4694    def get_results_table(
4695        self,
4696        episode_names: List,
4697        metrics: List = None,
4698        include_std: bool = False,
4699        classes: List = None,
4700    ):
4701        """
4702        Genererate a `pandas` dataframe with a summary of episode results
4703
4704        Parameters
4705        ----------
4706        episode_names : list
4707            a list of names of episodes to include
4708        metrics : list, optional
4709            a list of metric names to include
4710        include_std : bool, default False
4711            if `True`, for episodes with multiple runs the mean and standard deviation will be displayed;
4712            otherwise only mean
4713        classes : list, optional
4714            a list of names of classes to include (by default all are included)
4715
4716        Returns
4717        -------
4718        results : pd.DataFrame
4719            a table with the results
4720        """
4721
4722        run_names = []
4723        for episode in episode_names:
4724            run_names += self._episodes().get_runs(episode)
4725        episodes = self.list_episodes(run_names, print_results=False)
4726        metric_columns = [x for x in episodes.columns if x[0] == "results"]
4727        results_df = pd.DataFrame()
4728        if metrics is not None:
4729            metric_columns = [
4730                x for x in metric_columns if x[1].split("_")[0] in metrics
4731            ]
4732        for episode in episode_names:
4733            results = []
4734            metric_set = set()
4735            for run in self._episodes().get_runs(episode):
4736                beh_dict = self.get_behavior_dictionary(run)
4737                res_dict = defaultdict(lambda: {})
4738                for column in metric_columns:
4739                    if np.isnan(episodes.loc[run, column]):
4740                        continue
4741                    split = column[1].split("_")
4742                    if split[-1].isnumeric():
4743                        beh_ind = int(split[-1])
4744                        metric_name = "_".join(split[:-1])
4745                        beh = beh_dict[beh_ind]
4746                    else:
4747                        beh = "average"
4748                        metric_name = column[1]
4749                    res_dict[beh][metric_name] = episodes.loc[run, column]
4750                    metric_set.add(metric_name)
4751                if "average" not in res_dict:
4752                    res_dict["average"] = {}
4753                for metric in metric_set:
4754                    if metric not in res_dict["average"]:
4755                        arr = [
4756                            res_dict[beh][metric]
4757                            for beh in res_dict
4758                            if metric in res_dict[beh]
4759                        ]
4760                        res_dict["average"][metric] = np.mean(arr)
4761                results.append(res_dict)
4762            episode_results = {}
4763            for metric in metric_set:
4764                for beh in results[0].keys():
4765                    if classes is not None and beh not in classes:
4766                        continue
4767                    arr = []
4768                    for res_dict in results:
4769                        if metric in res_dict[beh]:
4770                            arr.append(res_dict[beh][metric])
4771                    if len(arr) > 0:
4772                        if include_std:
4773                            episode_results[
4774                                (beh, f"{episode} {metric} mean")
4775                            ] = np.mean(arr)
4776                            episode_results[(beh, f"{episode} {metric} std")] = np.std(
4777                                arr
4778                            )
4779                        else:
4780                            episode_results[(beh, f"{episode} {metric}")] = np.mean(arr)
4781            for key, value in episode_results.items():
4782                results_df.loc[key[0], key[1]] = value
4783        print(f"RESULTS:")
4784        print(results_df)
4785        print("\n")
4786        return results_df
4787
4788    def episode_exists(self, episode_name: str) -> bool:
4789        """
4790        Check if an episode already exists
4791
4792        Parameters
4793        ----------
4794        episode_name : str
4795            the episode name
4796
4797        Returns
4798        -------
4799        exists : bool
4800            `True` if the episode exists
4801        """
4802
4803        return self._episodes().check_name_validity(episode_name)
4804
4805    def search_exists(self, search_name: str) -> bool:
4806        """
4807        Check if a search already exists
4808
4809        Parameters
4810        ----------
4811        search_name : str
4812            the search name
4813
4814        Returns
4815        -------
4816        exists : bool
4817            `True` if the search exists
4818        """
4819
4820        return self._searches().check_name_validity(search_name)
4821
4822    def prediction_exists(self, prediction_name: str) -> bool:
4823        """
4824        Check if a prediction already exists
4825
4826        Parameters
4827        ----------
4828        prediction_name : str
4829            the prediction name
4830
4831        Returns
4832        -------
4833        exists : bool
4834            `True` if the prediction exists
4835        """
4836
4837        return self._predictions().check_name_validity(prediction_name)
4838
4839    @staticmethod
4840    def project_name_available(projects_path: str, project_name: str):
4841        if projects_path is None:
4842            projects_path = os.path.join(str(Path.home()), "DLC2Action")
4843        return not os.path.exists(os.path.join(projects_path, project_name))
4844
4845    def _update_episode_metrics(self, episode_name: str, metrics: Dict):
4846        """
4847        Update meta data with evaluation results
4848        """
4849
4850        self._episodes().update_episode_metrics(episode_name, metrics)
4851
4852    def rename_episode(self, episode_name: str, new_episode_name: str):
4853        shutil.move(
4854            os.path.join(self.project_path, "results", "model", episode_name),
4855            os.path.join(self.project_path, "results", "model", new_episode_name),
4856        )
4857        shutil.move(
4858            os.path.join(self.project_path, "results", "logs", f"{episode_name}.txt"),
4859            os.path.join(
4860                self.project_path, "results", "logs", f"{new_episode_name}.txt"
4861            ),
4862        )
4863        self._episodes().rename_episode(episode_name, new_episode_name)
4864
4865
4866class _Runner:
4867    """
4868    A helper class for running hyperparameter searches
4869    """
4870
4871    def __init__(
4872        self,
4873        search_name,
4874        search_space: Dict,
4875        load_episode: str,
4876        load_epoch: int,
4877        metric: str,
4878        average: int,
4879        task: Union[TaskDispatcher, None],
4880        remove_saved_features: bool,
4881        project: Project,
4882    ):
4883        """
4884        Parameters
4885        ----------
4886        search_space : dict
4887            a dictionary representing the search space; of this general structure:
4888            {'group/param_name': ('float/int/float_log/int_log', start, end),
4889            'group/param_name': ('categorical', [choices])}, e.g.
4890            {'data/overlap': ('int', 5, 100), 'training/lr': ('float_log', 1e-4, 1e-2),
4891            'data/feature_extraction': ('categorical', ['kinematic', 'bones'])}
4892        load_episode : str
4893            the name of the episode to load the model from
4894        load_epoch : int
4895            the epoch to load the model from (if not provided, the last checkpoint is used)
4896        metric : str
4897            the metric to maximize/minimize (see direction)
4898        average : int
4899            the number of epochs to average the metric; if 0, the last value is taken
4900        remove_saved_features : bool
4901            if `True`, the old datasets will be deleted when data parameters change
4902        project : Project
4903            the parent `Project` instance
4904        """
4905
4906        self.search_space = search_space
4907        self.load_episode = load_episode
4908        self.load_epoch = load_epoch
4909        self.metric = metric
4910        self.average = average
4911        self.feature_save_path = None
4912        self.remove_saved_featuress = remove_saved_features
4913        self.save_stores = project._save_stores
4914        self.remove_datasets = project.remove_saved_features
4915        self.task = task
4916        self.search_name = search_name
4917        self.update = project._update
4918        self.remove_episode = project.remove_episode
4919        self.fill = project._fill
4920
4921    def clean(self):
4922        """
4923        Remove datasets if needed
4924        """
4925
4926        if self.remove_saved_featuress:
4927            self.remove_datasets([os.path.basename(self.feature_save_path)])
4928
4929    def run(self, trial, parameters):
4930        """
4931        Make a trial run
4932        """
4933
4934        params = deepcopy(parameters)
4935        param_update = defaultdict(
4936            lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: {})))
4937        )
4938        for full_name, space in self.search_space.items():
4939            group, param_name = full_name.split("/")[0], "/".join(
4940                full_name.split("/")[1:]
4941            )
4942            log = space[0][-3:] == "log"
4943            if space[0].startswith("int"):
4944                value = trial.suggest_int(full_name, space[1], space[2], log=log)
4945            elif space[0].startswith("float"):
4946                value = trial.suggest_float(full_name, space[1], space[2], log=log)
4947            elif space[0] == "categorical":
4948                value = trial.suggest_categorical(full_name, space[1])
4949            else:
4950                raise ValueError(
4951                    "The search space has to be formatted as either "
4952                    '("float"/"int"/"float_log"/"int_log", start, end) '
4953                    f'or ("categorical", [choices]); got {space} for {group}/{param_name}'
4954                )
4955            if len(param_name.split("/")) == 1:
4956                param_update[group][param_name] = value
4957            else:
4958                pars = param_name.split("/")
4959                pars = [int(x) if x.isnumeric() else x for x in pars]
4960                if len(pars) == 2:
4961                    param_update[group][pars[0]][pars[1]] = value
4962                elif len(pars) == 3:
4963                    param_update[group][pars[0]][pars[1]][pars[2]] = value
4964                elif len(pars) == 4:
4965                    param_update[group][pars[0]][pars[1]][pars[2]][pars[3]] = value
4966        params = self.update(params, param_update)
4967        self.remove_episode(f"_{self.search_name}")
4968        params = self.fill(
4969            params,
4970            f"_{self.search_name}",
4971            self.load_episode,
4972            load_epoch=self.load_epoch,
4973            only_load_model=True,
4974        )
4975        if self.feature_save_path != params["data"]["feature_save_path"]:
4976            if self.feature_save_path is not None:
4977                self.clean()
4978            self.feature_save_path = params["data"]["feature_save_path"]
4979        self.save_stores(params)
4980        if self.task is None:
4981            self.task = TaskDispatcher(deepcopy(params))
4982        else:
4983            self.task.update_task(params)
4984
4985        _, metrics_log = self.task.train(trial, self.metric)
4986        metric_values = metrics_log["val"][self.metric]
4987        if self.average > 0:
4988            value = np.mean(sorted(metric_values)[-self.average :])
4989        else:
4990            value = metric_values[-1]
4991        return value
class Project:
  51class Project:
  52    """
  53    A class to create and maintain the project files + keep track of experiments
  54    """
  55
  56    def __init__(
  57        self,
  58        name: str,
  59        data_type: str = None,
  60        annotation_type: str = "none",
  61        projects_path: str = None,
  62        data_path: Union[str, List] = None,
  63        annotation_path: Union[str, List] = None,
  64        copy: bool = False,
  65    ) -> None:
  66        """
  67        Parameters
  68        ----------
  69        name : str
  70            name of the project
  71        data_type : str, optional
  72            data type (run Project.data_types() to see available options; has to be provided if the project is being
  73            created)
  74        annotation_type : str, default 'none'
  75            annotation type (run Project.annotation_types() to see available options)
  76        projects_path : str, optional
  77            path to the projects folder (is filled with ~/DLC2Action by default)
  78        data_path : str, optional
  79            path to the folder containing input files for the project (has to be provided if the project is being
  80            created)
  81        annotation_path : str, optional
  82            path to the folder containing annotation files for the project
  83        copy : bool, default False
  84            if True, the files from annotation_path and data_path will be copied to the projects folder;
  85            otherwise they will be moved
  86        """
  87
  88        if projects_path is None:
  89            projects_path = os.path.join(str(Path.home()), "DLC2Action")
  90        if not os.path.exists(projects_path):
  91            os.mkdir(projects_path)
  92        self.project_path = os.path.join(projects_path, name)
  93        self.name = name
  94        self.data_type = data_type
  95        self.annotation_type = annotation_type
  96        self.data_path = data_path
  97        self.annotation_path = annotation_path
  98        if not os.path.exists(self.project_path):
  99            if data_type is None:
 100                raise ValueError(
 101                    "The data_type parameter is necessary when creating a new project!"
 102                )
 103            self._initialize_project(
 104                data_type, annotation_type, data_path, annotation_path, copy
 105            )
 106        else:
 107            self.annotation_type, self.data_type = self._read_types()
 108            if data_type != self.data_type and data_type is not None:
 109                raise ValueError(
 110                    f"The project has already been initialized with data_type={self.data_type}!"
 111                )
 112            if annotation_type != self.annotation_type and annotation_type != "none":
 113                raise ValueError(
 114                    f"The project has already been initialized with annotation_type={self.annotation_type}!"
 115                )
 116            self.annotation_path, data_path = self._read_paths()
 117            if self.data_path is None:
 118                self.data_path = data_path
 119            # if data_path != self.data_path and data_path is not None:
 120            #     raise ValueError(
 121            #         f"The project has already been initialized with data_path={self.data_path}!"
 122            #     )
 123            if annotation_path != self.annotation_path and annotation_path is not None:
 124                raise ValueError(
 125                    f"The project has already been initialized with annotation_path={self.annotation_path}!"
 126                )
 127        self._update_configs()
 128
    def _aggregate_predictions(
        self,
        prediction_name: str,
        episode_names: List,
        load_epochs: List = None,
        parameters_update: Dict = None,
        data_path: str = None,
        file_paths: Set = None,
        mode: str = "all",
        augment_n: int = 0,
        evaluate: bool = False,
        task: TaskDispatcher = None,
        embedding: bool = False,
    ) -> Tuple[TaskDispatcher, Dict, str, Dict, None]:
        """
        Generate a prediction averaged over several episodes

        A separate prediction is made for every episode with `_make_prediction`, converted with
        `generate_full_length_prediction` and accumulated per video, clip and behavior name; the sums are
        then divided by the number of contributions.

        Parameters
        ----------
        prediction_name : str
            the prediction name (passed on to `_make_prediction`)
        episode_names : list
            the names of the episodes to aggregate; their behavior dictionaries have to contain the same
            set of behavior names
        load_epochs : list, optional
            one epoch per episode (by default `None` for every episode)
        parameters_update, data_path, file_paths, mode, augment_n, evaluate, embedding
            passed on to `_make_prediction` unchanged
        task : TaskDispatcher, optional
            passed on to `_make_prediction`; the task returned for one episode is handed to the next one

        Returns
        -------
        task : TaskDispatcher
            the task returned for the last episode
        parameters : dict
            the parameters returned for the last episode
        data_mode : str
            the data mode returned for the last episode
        prediction : dict
            a nested dictionary (video id -> clip id -> concatenated prediction)
        None
            a placeholder for the fifth element of the tuple

        Raises
        ------
        ValueError
            if the behavior sets of the episodes differ
        """

        if load_epochs is None:
            load_epochs = [None for _ in episode_names]
        # running sums and contribution counts: video id -> clip id -> behavior name
        prediction = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: 0)))
        cnt = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: 0)))
        behs = set(self.get_behavior_dictionary(episode_names[0]).values())
        if not all(
            [
                set(self.get_behavior_dictionary(x).values()) == behs
                for x in episode_names
            ]
        ):
            raise ValueError(f"The behavior sets are different in {episode_names}")
        behaviors = set()
        for i, episode_name in enumerate(episode_names):
            # each episode is predicted on its own; the returned task is reused for the next episode
            task, parameters, data_mode, new_pred, _ = self._make_prediction(
                prediction_name,
                episode_names=[episode_name],
                load_epochs=[load_epochs[i]],
                parameters_update=parameters_update,
                data_path=data_path,
                file_paths=file_paths,
                mode=mode,
                augment_n=augment_n,
                evaluate=evaluate,
                task=task,
                embedding=embedding,
            )
            new_pred = task.dataset(data_mode).generate_full_length_prediction(new_pred)
            beh_dict = task.behaviors_dict()
            for video_id, video_values in new_pred.items():
                for clip_id, clip_prediction in video_values.items():
                    # the first axis of a clip prediction is indexed by label index; accumulating by
                    # behavior name keeps episodes with differently ordered labels aligned
                    for beh_i in range(clip_prediction.shape[0]):
                        prediction[video_id][clip_id][
                            beh_dict[beh_i]
                        ] += clip_prediction[beh_i, :].unsqueeze(0)
                        cnt[video_id][clip_id][beh_dict[beh_i]] += 1
                        behaviors.add(beh_dict[beh_i])
        output = defaultdict(lambda: {})
        # the set collected above is replaced here by the behavior names of the last task,
        # ordered by label index (the -100 key is excluded)
        behavior_indices = sorted(
            [x for x in task.behaviors_dict().keys() if x != -100]
        )
        behaviors = [task.behaviors_dict()[key] for key in behavior_indices]
        for video_id, video_values in prediction.items():
            for clip_id, clip_values in video_values.items():
                # average every behavior and stack the results along the first axis
                pred = torch.cat(
                    [
                        clip_values[beh] / cnt[video_id][clip_id][beh]
                        for beh in behaviors
                    ],
                    0,
                )
                output[video_id][clip_id] = pred
        return task, parameters, data_mode, dict(output), None
 201
 202    def _make_prediction(
 203        self,
 204        prediction_name: str,
 205        episode_names: List,
 206        load_epochs: List = None,
 207        parameters_update: Dict = None,
 208        data_path: str = None,
 209        file_paths: Set = None,
 210        mode: str = "all",
 211        augment_n: int = 0,
 212        evaluate: bool = False,
 213        task: TaskDispatcher = None,
 214        embedding: bool = False,
 215    ) -> Tuple[TaskDispatcher, Dict, str, torch.Tensor]:
 216        """
 217        Generate a prediction
 218        """
 219
 220        names = []
 221        epochs = []
 222        if load_epochs is None:
 223            load_epochs = [None for _ in episode_names]
 224        if len(load_epochs) != len(episode_names):
 225            raise ValueError(
 226                f"The length of load_epochs and the length of episode_names should be the same!"
 227            )
 228        for i, episode_name in enumerate(episode_names):
 229            names += self._episodes().get_runs(episode_name)
 230            epochs.append(load_epochs[i])
 231        if len(names) == 0:
 232            warnings.warn(f"None of the episodes {episode_names} exist!")
 233            names = [None]
 234        episodes = self._episodes()
 235        lengths = [
 236            episodes.load_parameters(name)["general"]["len_segment"] for name in names
 237        ]
 238        overlaps = [
 239            episodes.load_parameters(name)["general"]["overlap"] for name in names
 240        ]
 241        if not all([x == lengths[0] for x in lengths]):
 242            raise ValueError(f"Episodes {episode_names} have different segment lengths")
 243        if not all([x == overlaps[0] for x in overlaps]):
 244            raise ValueError(f"Episodes {episode_names} have different overlaps")
 245        load_epochs = epochs
 246        prediction = None
 247        decision_thresholds = None
 248        time_total = 0
 249        behavior_dicts = [
 250            self.get_behavior_dictionary(episode_name) for episode_name in names
 251        ]
 252        if not all(
 253            [
 254                set(d.values()) == set(behavior_dicts[0].values())
 255                for d in behavior_dicts[1:]
 256            ]
 257        ):
 258            raise ValueError(
 259                f"Episodes {episode_names} have different sets of behaviors!"
 260            )
 261        behavior_indices = [x for x in behavior_dicts[0].keys() if x != -100]
 262        behaviors = [behavior_dicts[0][i] for i in behavior_indices]
 263        cnt = defaultdict(lambda: 0)
 264        behavior_probs = defaultdict(lambda: 0)
 265        for episode_name, load_epoch, behavior_dict in zip(
 266            names, load_epochs, behavior_dicts
 267        ):
 268            print(f"episode {episode_name}")
 269            task, parameters, data_mode = self._make_task_prediction(
 270                prediction_name=prediction_name,
 271                load_episode=episode_name,
 272                parameters_update=parameters_update,
 273                load_epoch=load_epoch,
 274                data_path=data_path,
 275                mode=mode,
 276                file_paths=file_paths,
 277                task=task,
 278                decision_thresholds=decision_thresholds,
 279            )
 280            behavior_indices_cur = [x for x in behavior_dict.keys() if x != -100]
 281            behaviors_cur = [behavior_dict[i] for i in behavior_indices_cur]
 282            # data_mode = "train" if mode == "all" else mode
 283            time_start = time.time()
 284            new_pred = task.predict(
 285                data_mode,
 286                raw_output=True,
 287                apply_primary_function=True,
 288                augment_n=augment_n,
 289                embedding=embedding,
 290            )
 291            for j, beh in enumerate(behaviors_cur):
 292                cnt[beh] += 1
 293                behavior_probs[beh] += new_pred[:, j, :].unsqueeze(1)
 294            # indices = [
 295            #     behaviors.index(behavior_dict[i]) for i in range(new_pred.shape[1])
 296            # ]
 297            # new_pred = new_pred[:, indices, :]
 298            time_end = time.time()
 299            time_total += time_end - time_start
 300            if evaluate:
 301                _, metrics = task.evaluate_prediction(new_pred, data=data_mode)
 302                if mode == "val":
 303                    self._update_episode_metrics(episode_name, metrics)
 304            # if prediction is None:
 305            #     prediction = new_pred
 306            # else:
 307            #     prediction += new_pred
 308            print("\n")
 309        prediction = torch.cat([behavior_probs[beh] / cnt[beh] for beh in behaviors], 1)
 310        hours = int(time_total // 3600)
 311        time_total -= hours * 3600
 312        minutes = int(time_total // 60)
 313        time_total -= minutes * 60
 314        seconds = int(time_total)
 315        inference_time = f"{hours}:{minutes:02}:{seconds:02}"
 316        # prediction /= len(names)
 317        return task, parameters, data_mode, prediction, inference_time
 318
 319    def _make_task_prediction(
 320        self,
 321        prediction_name: str,
 322        load_episode: str = None,
 323        parameters_update: Dict = None,
 324        load_epoch: int = None,
 325        data_path: str = None,
 326        mode: str = "val",
 327        file_paths: Set = None,
 328        decision_thresholds: List = None,
 329        task: TaskDispatcher = None,
 330    ) -> Tuple[TaskDispatcher, Dict, str]:
 331        """
 332        Make a `TaskDispatcher` object that will be used to generate a prediction
 333        """
 334
 335        if parameters_update is None:
 336            parameters_update = {}
 337        parameters_update_second = {}
 338        if mode == "all" or data_path is not None or file_paths is not None:
 339            parameters_update_second["training"] = {
 340                "val_frac": 0,
 341                "test_frac": 0,
 342                "partition_method": "random",
 343                "save_split": False,
 344                "split_path": None,
 345            }
 346            mode = "train"
 347        if decision_thresholds is not None:
 348            if (
 349                len(decision_thresholds)
 350                == self._episode(load_episode).get_num_classes()
 351            ):
 352                parameters_update_second["general"] = {
 353                    "threshold_value": decision_thresholds
 354                }
 355            else:
 356                raise ValueError(
 357                    f"The length of the decision thresholds {decision_thresholds} "
 358                    f"must be equal to the length of the behaviors dictionary "
 359                    f"{self._episode(load_episode).get_behaviors_dict()}"
 360                )
 361        data_param_update = {}
 362        if data_path is not None:
 363            data_param_update = {"data_path": data_path}
 364        if file_paths is not None:
 365            data_param_update = {"data_path": None, "file_paths": file_paths}
 366        parameters_update = self._update(parameters_update, {"data": data_param_update})
 367        if data_path is not None or file_paths is not None:
 368            general_update = {
 369                "annotation_type": "none",
 370                "only_load_annotated": False,
 371            }
 372        else:
 373            general_update = {}
 374        parameters_update = self._update(parameters_update, {"general": general_update})
 375        task, parameters = self._make_task(
 376            episode_name=prediction_name,
 377            load_episode=load_episode,
 378            parameters_update=parameters_update,
 379            parameters_update_second=parameters_update_second,
 380            load_epoch=load_epoch,
 381            purpose="prediction",
 382            task=task,
 383            behaviors=self.get_behavior_dictionary(load_episode),
 384        )
 385        # if data_path is not None or file_paths is not None:
 386        #     print('SETTING')
 387        #     task.set_behaviors(self.get_behavior_dictionary(load_episode))
 388        if mode is None:
 389            if task.exists("test"):
 390                mode = "test"
 391            elif task.exists("val"):
 392                mode = "val"
 393            else:
 394                mode = "train"
 395        return task, parameters, mode
 396
 397    def _make_task_training(
 398        self,
 399        episode_name: str,
 400        load_episode: str = None,
 401        parameters_update: Dict = None,
 402        load_epoch: int = None,
 403        load_search: str = None,
 404        load_parameters: list = None,
 405        round_to_binary: list = None,
 406        load_strict: bool = True,
 407        continuing: bool = False,
 408        task: TaskDispatcher = None,
 409        mask_name: str = None,
 410        throwaway: bool = False,
 411    ) -> Tuple[TaskDispatcher, Dict, str]:
 412        """
 413        Make a `TaskDispatcher` object that will be used to generate a prediction
 414        """
 415
 416        if parameters_update is None:
 417            parameters_update = {}
 418        if continuing:
 419            purpose = "continuing"
 420        else:
 421            purpose = "training"
 422        if mask_name is not None:
 423            mask_name = os.path.join(self._mask_path(), f"{mask_name}.pickle")
 424        parameters_update_second = {"data": {"real_lens": mask_name}}
 425        if throwaway:
 426            parameters_update = self._update(
 427                parameters_update, {"training": {"normalize": False, "device": "cpu"}}
 428            )
 429        return self._make_task(
 430            episode_name,
 431            load_episode,
 432            parameters_update,
 433            parameters_update_second,
 434            load_epoch,
 435            load_search,
 436            load_parameters,
 437            round_to_binary,
 438            purpose,
 439            task,
 440            load_strict=load_strict,
 441        )
 442
 443    def _make_parameters(
 444        self,
 445        episode_name: str,
 446        load_episode: str = None,
 447        parameters_update: Dict = None,
 448        parameters_update_second: Dict = None,
 449        load_epoch: int = None,
 450        load_search: str = None,
 451        load_parameters: list = None,
 452        round_to_binary: list = None,
 453        purpose: str = "train",
 454        load_strict: bool = True,
 455    ):
 456        """
 457        Construct a parameters dictionary
 458        """
 459
 460        if parameters_update is None:
 461            parameters_update = {}
 462        pars_update = deepcopy(parameters_update)
 463        if parameters_update_second is None:
 464            parameters_update_second = {}
 465        if purpose == "prediction" and "model" in pars_update.keys():
 466            raise ValueError("Cannot change model parameters after training!")
 467        if purpose in ["continuing", "prediction"] and load_episode is not None:
 468            read_parameters = self._read_parameters()
 469            parameters = self._episodes().load_parameters(load_episode)
 470            parameters["metrics"] = self._update(
 471                read_parameters["metrics"], parameters["metrics"]
 472            )
 473            parameters["ssl"] = self._update(
 474                read_parameters["ssl"], parameters.get("ssl", {})
 475            )
 476        else:
 477            parameters = self._read_parameters()
 478        if "model" in pars_update:
 479            model_params = pars_update.pop("model")
 480        else:
 481            model_params = None
 482        if "features" in pars_update:
 483            feat_params = pars_update.pop("features")
 484        else:
 485            feat_params = None
 486        if "augmentations" in pars_update:
 487            aug_params = pars_update.pop("augmentations")
 488        else:
 489            aug_params = None
 490        parameters = self._update(parameters, pars_update)
 491        if pars_update.get("general", {}).get("model_name") is not None:
 492            model_name = parameters["general"]["model_name"]
 493            parameters["model"] = self._open_yaml(
 494                os.path.join(self.project_path, "config", "model", f"{model_name}.yaml")
 495            )
 496        if pars_update.get("general", {}).get("feature_extraction") is not None:
 497            feat_name = parameters["general"]["feature_extraction"]
 498            parameters["features"] = self._open_yaml(
 499                os.path.join(
 500                    self.project_path, "config", "features", f"{feat_name}.yaml"
 501                )
 502            )
 503            aug_name = options.extractor_to_transformer[
 504                parameters["general"]["feature_extraction"]
 505            ]
 506            parameters["augmentations"] = self._open_yaml(
 507                os.path.join(
 508                    self.project_path, "config", "augmentations", f"{aug_name}.yaml"
 509                )
 510            )
 511        if model_params is not None:
 512            parameters["model"] = self._update(parameters["model"], model_params)
 513        if feat_params is not None:
 514            parameters["features"] = self._update(parameters["features"], feat_params)
 515        if aug_params is not None:
 516            parameters["augmentations"] = self._update(
 517                parameters["augmentations"], aug_params
 518            )
 519        if load_search is not None:
 520            parameters = self._update_with_search(
 521                parameters, load_search, load_parameters, round_to_binary
 522            )
 523        parameters = self._fill(
 524            parameters,
 525            episode_name,
 526            load_episode,
 527            load_epoch=load_epoch,
 528            load_strict=load_strict,
 529            only_load_model=(purpose != "continuing"),
 530            continuing=(purpose in ["prediction", "continuing"]),
 531            enforce_split_parameters=(purpose == "prediction"),
 532        )
 533        parameters = self._update(parameters, parameters_update_second)
 534        return parameters
 535
 536    def _make_task(
 537        self,
 538        episode_name: str,
 539        load_episode: str = None,
 540        parameters_update: Dict = None,
 541        parameters_update_second: Dict = None,
 542        load_epoch: int = None,
 543        load_search: str = None,
 544        load_parameters: list = None,
 545        round_to_binary: list = None,
 546        purpose: str = "train",
 547        task: TaskDispatcher = None,
 548        load_strict: bool = True,
 549        behaviors: Dict = None,
 550    ) -> Tuple[TaskDispatcher, Union[CommentedMap, dict]]:
 551        """
 552        Make a `TaskDispatcher` object
 553
 554        The task parameters are read from the config files and then updated with the
 555        parameters_update dictionary. The model can be either initialized from scratch or loaded from one of the
 556        previous experiments. All parameters and results are saved in the meta files and can be accessed with the
 557        list_episodes() function. The train/test/validation split is saved and loaded from a file whenever the
 558        same split parameters are used. The pre-computed datasets are also saved and loaded whenever the same
 559        data parameters are used.
 560
 561        Parameters
 562        ----------
 563        episode_name : str
 564            the name of the episode
 565        load_episode : str, optional
 566            the (previously run) episode name to load the model from
 567        parameters_update : dict, optional
 568            the dictionary used to update the parameters from the config
 569        parameters_update_second : dict, optional
 570            the dictionary used to update the parameters after the automatic fill-out
 571        load_epoch : int, optional
 572            the epoch to load (if load_episodes is not None); if not provided, the last epoch is used
 573        load_search : str, optional
 574            the hyperparameter search result to load
 575        load_parameters : list, optional
 576            a list of string names of the parameters to load from load_search (if not provided, all parameters
 577            are loaded)
 578        round_to_binary : list, optional
 579            a list of string names of the loaded parameters that should be rounded to the nearest power of two
 580        purpose : {"train", "continuing", "prediction"}
 581            the purpose of the task object (`"train"` for training from scratch, `"continuing"` for continuing
 582            the training of an interrupted episode, `"prediction"` for generating a prediction)
 583        task : TaskDispatcher, optional
 584            a pre-existing task; if provided, the method will update the task instead of creating a new one
 585            (this might save time, mainly on dataset loading)
 586
 587        Returns
 588        -------
 589        task : TaskDispatcher
 590            the `TaskDispatcher` instance
 591        parameters : dict
 592            the parameters dictionary that describes the task
 593        """
 594
 595        parameters = self._make_parameters(
 596            episode_name,
 597            load_episode,
 598            parameters_update,
 599            parameters_update_second,
 600            load_epoch,
 601            load_search,
 602            load_parameters,
 603            round_to_binary,
 604            purpose,
 605            load_strict=load_strict,
 606        )
 607        if parameters["data"].get("annotation_type", "none") == "none":
 608            parameters = self._update(
 609                parameters, {"data": {"behavior_dictionary": behaviors}}
 610            )
 611        if task is None:
 612            task = TaskDispatcher(parameters)
 613        else:
 614            task.update_task(parameters)
 615        self._save_stores(parameters)
 616        return task, parameters
 617
 618    def run_episode(
 619        self,
 620        episode_name: str,
 621        load_episode: str = None,
 622        parameters_update: Dict = None,
 623        task: TaskDispatcher = None,
 624        load_epoch: int = None,
 625        load_search: str = None,
 626        load_parameters: list = None,
 627        round_to_binary: list = None,
 628        load_strict: bool = True,
 629        n_seeds: int = 1,
 630        force: bool = False,
 631        suppress_name_check: bool = False,
 632        remove_saved_features: bool = False,
 633        mask_name: str = None,
 634        autostop_metric: str = None,
 635        autostop_interval: int = 50,
 636        autostop_threshold: float = 0.001,
 637        loading_bar: bool = False,
 638        trial: Tuple = None,
 639    ) -> TaskDispatcher:
 640        """
 641        Run an episode
 642
 643        The task parameters are read from the config files and then updated with the
 644        parameters_update dictionary. The model can be either initialized from scratch or loaded from one of the
 645        previous experiments. All parameters and results are saved in the meta files and can be accessed with the
 646        list_episodes() function. The train/test/validation split is saved and loaded from a file whenever the
 647        same split parameters are used. The pre-computed datasets are also saved and loaded whenever the same
 648        data parameters are used.
 649
 650        You can use the autostop parameters to finish training when the parameters are not improving. It will be
 651        stopped if the average value of `autostop_metric` over the last `autostop_interval` epochs is smaller than
 652        the average over the previous `autostop_interval` epochs + `autostop_threshold`. For example, if the
 653        current epoch is 120 and `autostop_interval` is 50, the averages over epochs 70-120 and 20-70 will be compared.
 654
 655        Parameters
 656        ----------
 657        episode_name : str
 658            the episode name
 659        load_episode : str, optional
 660            the (previously run) episode name to load the model from; if the episode has multiple runs,
 661            the new episode will have the same number of runs, each starting with one of the pre-trained models
 662        parameters_update : dict, optional
 663            the dictionary used to update the parameters from the config files
 664        task : TaskDispatcher, optional
 665            a pre-existing `TaskDispatcher` object (if provided, the method will update it instead of creating
 666            a new instance)
 667        load_epoch : int, optional
 668            the epoch to load (if load_episodes is not None); if not provided, the last epoch is used
 669        load_search : str, optional
 670            the hyperparameter search result to load
 671        load_parameters : list, optional
 672            a list of string names of the parameters to load from load_search (if not provided, all parameters
 673            are loaded)
 674        round_to_binary : list, optional
 675            a list of string names of the loaded parameters that should be rounded to the nearest power of two
 676        load_strict : bool, default True
 677            if `False`, matching weights will be loaded from `load_episode` and differences in parameter name lists and
 678            weight shapes will be ignored; otherwise mismatches will prompt a `RuntimeError`
 679        n_seeds : int, default 1
 680            the number of runs to perform with different random seeds; if `n_seeds > 1`, the episodes will be named
 681            `episode_name::seed_index`, e.g. `test_episode::0` and `test_episode::1`
 682        force : bool, default False
 683            if `True` and an episode with name `episode_name` already exists, it will be overwritten (use with caution!)
 684        suppress_name_check : bool, default False
 685            if `True`, episode names with a double colon are allowed (please don't use this option unless you understand
 686            why they are usually forbidden)
 687        remove_saved_features : bool, default False
 688            if `True`, the dataset will be deleted after training
 689        mask_name : str, optional
 690            the name of the real_lens to apply
 691        autostop_interval : int, default 50
 692            the number of epochs to average the autostop metric over
 693        autostop_threshold : float, default 0.001
 694            the autostop difference threshold
 695        autostop_metric : str, optional
 696            the autostop metric (can be any one of the tracked metrics of `'loss'`)
 697        """
 698
 699        if type(n_seeds) is not int or n_seeds < 1:
 700            raise ValueError(
 701                f"The n_seeds parameter has to be an integer larger than 0; got {n_seeds}"
 702            )
 703        if n_seeds > 1 and mask_name is not None:
 704            raise ValueError("Cannot apply a real_lens with n_seeds > 1")
 705        self._check_episode_validity(
 706            episode_name, allow_doublecolon=suppress_name_check, force=force
 707        )
 708        load_runs = self._episodes().get_runs(load_episode)
 709        if len(load_runs) > 1:
 710            task = self.run_episodes(
 711                episode_names=[
 712                    f'{episode_name}::{run.split("::")[-1]}' for run in load_runs
 713                ],
 714                load_episodes=load_runs,
 715                parameters_updates=[parameters_update for _ in load_runs],
 716                load_epochs=[load_epoch for _ in load_runs],
 717                load_searches=[load_search for _ in load_runs],
 718                load_parameters=[load_parameters for _ in load_runs],
 719                round_to_binary=[round_to_binary for _ in load_runs],
 720                load_strict=[load_strict for _ in load_runs],
 721                suppress_name_check=True,
 722                force=force,
 723                remove_saved_features=False,
 724            )
 725            if remove_saved_features:
 726                self._remove_stores(
 727                    {
 728                        "general": task.general_parameters,
 729                        "data": task.data_parameters,
 730                        "features": task.feature_parameters,
 731                    }
 732                )
 733            if n_seeds > 1:
 734                warnings.warn(
 735                    f"The n_seeds parameter is disregarded since load_episode={load_episode} has multiple runs"
 736                )
 737        elif n_seeds > 1:
 738            self.run_episodes(
 739                episode_names=[f"{episode_name}::{i}" for i in range(n_seeds)],
 740                load_episodes=[load_episode for _ in range(n_seeds)],
 741                parameters_updates=[parameters_update for _ in range(n_seeds)],
 742                load_epochs=[load_epoch for _ in range(n_seeds)],
 743                load_searches=[load_search for _ in range(n_seeds)],
 744                load_parameters=[load_parameters for _ in range(n_seeds)],
 745                round_to_binary=[round_to_binary for _ in range(n_seeds)],
 746                load_strict=[load_strict for _ in range(n_seeds)],
 747                suppress_name_check=True,
 748                force=force,
 749                remove_saved_features=remove_saved_features,
 750            )
 751        else:
 752            print(f"TRAINING {episode_name}")
 753            try:
 754                task, parameters = self._make_task_training(
 755                    episode_name,
 756                    load_episode,
 757                    parameters_update,
 758                    load_epoch,
 759                    load_search,
 760                    load_parameters,
 761                    round_to_binary,
 762                    continuing=False,
 763                    task=task,
 764                    mask_name=mask_name,
 765                    load_strict=load_strict,
 766                )
 767                self._save_episode(
 768                    episode_name,
 769                    parameters,
 770                    task.behaviors_dict(),
 771                    norm_stats=task.get_normalization_stats(),
 772                )
 773                time_start = time.time()
 774                if trial is not None:
 775                    trial, metric = trial
 776                else:
 777                    trial, metric = None, None
 778                logs = task.train(
 779                    autostop_metric=autostop_metric,
 780                    autostop_interval=autostop_interval,
 781                    autostop_threshold=autostop_threshold,
 782                    loading_bar=loading_bar,
 783                    trial=trial,
 784                    optimized_metric=metric,
 785                )
 786                time_end = time.time()
 787                time_total = time_end - time_start
 788                hours = int(time_total // 3600)
 789                time_total -= hours * 3600
 790                minutes = int(time_total // 60)
 791                time_total -= minutes * 60
 792                seconds = int(time_total)
 793                training_time = f"{hours}:{minutes:02}:{seconds:02}"
 794                self._update_episode_results(episode_name, logs, training_time)
 795                if remove_saved_features:
 796                    self._remove_stores(parameters)
 797                print("\n")
 798                return task
 799
 800            except Exception as e:
 801                if isinstance(e, optuna.exceptions.TrialPruned):
 802                    raise e
 803                else:
 804                    # if str(e) != f"The {episode_name} episode name is already in use!":
 805                    #     self.remove_episode(episode_name)
 806                    raise RuntimeError(f"Episode {episode_name} could not run")
 807
 808    def run_episodes(
 809        self,
 810        episode_names: List,
 811        load_episodes: List = None,
 812        parameters_updates: List = None,
 813        load_epochs: List = None,
 814        load_searches: List = None,
 815        load_parameters: List = None,
 816        round_to_binary: List = None,
 817        load_strict: List = None,
 818        force: bool = False,
 819        suppress_name_check: bool = False,
 820        remove_saved_features: bool = False,
 821    ) -> TaskDispatcher:
 822        """
 823        Run multiple episodes in sequence (and re-use previously loaded information)
 824
 825        For each episode, the task parameters are read from the config files and then updated with the
 826        parameter_update dictionary. The model can be either initialized from scratch or loaded from one of the
 827        previous experiments. All parameters and results are saved in the meta files and can be accessed with the
 828        list_episodes() function. The train/test/validation split is saved and loaded from a file whenever the
 829        same split parameters are used. The pre-computed datasets are also saved and loaded whenever the same
 830        data parameters are used.
 831
 832        Parameters
 833        ----------
 834        episode_names : list
 835            a list of strings of episode names
 836        load_episodes : list, optional
 837            a list of strings of (previously run) episode names to load the model from; if the episode has multiple runs,
 838            the new episode will have the same number of runs, each starting with one of the pre-trained models
 839        parameters_updates : list, optional
 840            a list of dictionaries used to update the parameters from the config
 841        load_epochs : list, optional
 842            a list of integers used to specify the epoch to load (if load_episodes is not None)
 843        load_searches : list, optional
 844            a list of strings of hyperparameter search results to load
 845        load_parameters : list, optional
 846            a list of lists of string names of the parameters to load from the searches
 847        round_to_binary : list, optional
 848            a list of string names of the loaded parameters that should be rounded to the nearest power of two
 849        load_strict : list, optional
 850            a list of boolean values specifying weight loading policy: if `False`, matching weights will be loaded from
 851            the corresponding episode and differences in parameter name lists and
 852            weight shapes will be ignored; otherwise mismatches will prompt a `RuntimeError` (by default `True` for
 853            every episode)
 854        force : bool, default False
 855            if `True` and an episode name is already taken, it will be overwritten (use with caution!)
 856        suppress_name_check : bool, default False
 857            if `True`, episode names with a double colon are allowed (please don't use this option unless you understand
 858            why they are usually forbidden)
 859        remove_saved_features : bool, default False
 860            if `True`, the dataset will be deleted after training
 861        """
 862
 863        task = None
 864        if load_searches is None:
 865            load_searches = [None for _ in episode_names]
 866        if load_episodes is None:
 867            load_episodes = [None for _ in episode_names]
 868        if parameters_updates is None:
 869            parameters_updates = [None for _ in episode_names]
 870        if load_parameters is None:
 871            load_parameters = [None for _ in episode_names]
 872        if load_epochs is None:
 873            load_epochs = [None for _ in episode_names]
 874        if load_strict is None:
 875            load_strict = [True for _ in episode_names]
 876        for (
 877            parameters_update,
 878            episode_name,
 879            load_episode,
 880            load_epoch,
 881            load_search,
 882            load_parameters_list,
 883            load_strict_value,
 884        ) in zip(
 885            parameters_updates,
 886            episode_names,
 887            load_episodes,
 888            load_epochs,
 889            load_searches,
 890            load_parameters,
 891            load_strict,
 892        ):
 893            task = self.run_episode(
 894                episode_name,
 895                load_episode,
 896                parameters_update,
 897                task,
 898                load_epoch,
 899                load_search,
 900                load_parameters_list,
 901                round_to_binary,
 902                load_strict_value,
 903                suppress_name_check=suppress_name_check,
 904                force=force,
 905                remove_saved_features=remove_saved_features,
 906            )
 907        return task
 908
 909    def continue_episode(
 910        self,
 911        episode_name: str,
 912        num_epochs: int = None,
 913        task: TaskDispatcher = None,
 914        n_seeds: int = 1,
 915        remove_saved_features: bool = False,
 916        device: str = "cuda",
 917        num_cpus: int = None,
 918    ) -> TaskDispatcher:
 919        """
 920        Load an older episode and continue running from the latest checkpoint
 921
 922        All parameters as well as the model and optimizer state dictionaries are loaded from the episode.
 923
 924        Parameters
 925        ----------
 926        episode_name : str
 927            the name of the episode to continue
 928        num_epochs : int, optional
 929            the new number of epochs
 930        task : TaskDispatcher, optional
 931            a pre-existing task; if provided, the method will update the task instead of creating a new one
 932            (this might save time, mainly on dataset loading)
 933        result_average_interval : int, default 5
 934            the metric are averaged over the last result_average_interval to be stored in the episodes meta file
 935            and displayed by list_episodes() function (the full log is still always available)
 936        n_seeds : int, default 1
 937            the number of runs to perform; if `n_seeds > 1`, the episodes will be named `episode_name::run_index`, e.g.
 938            `test_episode::0` and `test_episode::1`
 939        remove_saved_features : bool, default False
 940            if `True`, pre-computed features will be deleted after the run
 941        device : str, default "cuda"
 942            the torch device to use
 943        """
 944
 945        runs = self._episodes().get_runs(episode_name)
 946        for run in runs:
 947            print(f"TRAINING {run}")
 948            if num_epochs is None and not self._episode(run).unfinished():
 949                continue
 950            parameters_update = {
 951                "training": {
 952                    "num_epochs": num_epochs,
 953                    "device": device,
 954                },
 955                "general": {"num_cpus": num_cpus},
 956            }
 957            task, parameters = self._make_task_training(
 958                run,
 959                load_episode=run,
 960                parameters_update=parameters_update,
 961                continuing=True,
 962                task=task,
 963            )
 964            time_start = time.time()
 965            logs = task.train()
 966            time_end = time.time()
 967            old_time = self._training_time(run)
 968            if not np.isnan(old_time):
 969                time_end += old_time
 970                time_total = time_end - time_start
 971                hours = int(time_total // 3600)
 972                time_total -= hours * 3600
 973                minutes = int(time_total // 60)
 974                time_total -= minutes * 60
 975                seconds = int(time_total)
 976                training_time = f"{hours}:{minutes:02}:{seconds:02}"
 977            else:
 978                training_time = np.nan
 979            self._save_episode(
 980                run,
 981                parameters,
 982                task.behaviors_dict(),
 983                suppress_validation=True,
 984                training_time=training_time,
 985                norm_stats=task.get_normalization_stats(),
 986            )
 987            self._update_episode_results(run, logs)
 988            print("\n")
 989        if len(runs) < n_seeds:
 990            for i in range(len(runs), n_seeds):
 991                self.run_episode(
 992                    f"{episode_name}::{i}",
 993                    parameters_update=self._episodes().load_parameters(runs[0]),
 994                    task=task,
 995                    suppress_name_check=True,
 996                )
 997        if remove_saved_features:
 998            self._remove_stores(parameters)
 999        return task
1000
1001    def run_default_hyperparameter_search(
1002        self,
1003        search_name: str,
1004        model_name: str = None,
1005        metric: str = "f1",
1006        best_n: int = 3,
1007        direction: str = "maximize",
1008        load_episode: str = None,
1009        load_epoch: int = None,
1010        load_strict: bool = True,
1011        prune: bool = True,
1012        force: bool = False,
1013        remove_saved_features: bool = False,
1014        overlap: float = 0,
1015        num_epochs: int = 50,
1016        test_frac: float = 0,
1017        n_trials=150,
1018        device: str = None,
1019    ):
1020        """
1021        Run an optuna hyperparameter search with default parameters for a model
1022
1023        For the vast majority of cases, optimizing the default parameters should be enough.
1024        Check out `dlc2action.options.model_hyperparameters` for the lists of parameters.
1025        There are also options to set overlap, test fraction and number of epochs parameters for the search without
1026        modifying the project config files. However, if you want something more complex, look into
1027        `Project.run_hyperparameter_search`.
1028
1029        The task parameters are read from the config files and updated with the parameters_update dictionary.
1030        The model can be either initialized from scratch or loaded from a previously run episode.
1031        For each trial, the objective metric is averaged over a few best epochs.
1032
1033        Parameters
1034        ----------
1035        search_name : str
1036            the name of the search to store it in the meta files and load in run_episode
1037        model_name : str, optional
1038            the name of the model (by default loaded from the project settings, see `project.help('models')` for options)
1039        metric : str, default f1
1040            the metric to maximize/minimize (see direction); if the metric has an `"average"` parameter and it is set to
1041            `"none"` in the config files, it will be reset to `"macro"` for the search; see `project.help('metrics')` for options
1042        n_trials : int, default 20
1043            the number of optimization trials to run
1044        best_n : int, default 1
1045            the number of epochs to average the metric; if 0, the last value is taken
1046        parameters_update : dict, optional
1047            the parameters update dictionary
1048        direction : {'maximize', 'minimize'}
1049            optimization direction
1050        load_episode : str, optional
1051            the name of the episode to load the model from
1052        load_epoch : int, optional
1053            the epoch to load the model from (if not provided, the last checkpoint is used)
1054        prune : bool, default False
1055            if `True`, experiments where the optimized metric is improving too slowly will be terminated
1056            (with optuna HyperBand pruner)
1057        force : bool, default False
1058            if `True`, existing searches with the same name will be overwritten
1059        remove_saved_features : bool, default False
1060            if `True`, pre-computed features will be deleted after each run (if the data parameters change)
1061        device : str, optional
1062            cuda:{i} or cpu, if not given it is read from the default parameters
1063
1064        Returns
1065        -------
1066        dict
1067            a dictionary of best parameters
1068        """
1069
1070        if model_name is None:
1071            model_name = self._read_parameters()["general"]["model_name"]
1072        if model_name not in options.model_hyperparameters:
1073            raise ValueError(
1074                f"There is no default search space for {model_name}! Please choose from {options.model_hyperparameters.keys()} or try project.run_hyperparameter_search()"
1075            )
1076        pars = {
1077            "general": {
1078                "overlap": overlap,
1079                "model_name": model_name,
1080                "metric_functions": {metric},
1081            },
1082            "training": {"num_epochs": num_epochs},
1083        }
1084        if test_frac is not None:
1085            pars["training"]["test_frac"] = test_frac
1086        if not metric.split("_")[-1].isnumeric():
1087            project_pars = self._read_parameters()
1088            if project_pars["metrics"][metric].get("average") == "none":
1089                pars["metrics"] = {metric: {"average": "macro"}}
1090        if device is not None:
1091            pars["training"]["device"] = device
1092        return self.run_hyperparameter_search(
1093            search_name=search_name,
1094            search_space=options.model_hyperparameters[model_name],
1095            metric=metric,
1096            n_trials=n_trials,
1097            best_n=best_n,
1098            parameters_update=pars,
1099            direction=direction,
1100            load_episode=load_episode,
1101            load_epoch=load_epoch,
1102            load_strict=load_strict,
1103            prune=prune,
1104            force=force,
1105            remove_saved_features=remove_saved_features,
1106        )
1107
1108    def run_hyperparameter_search(
1109        self,
1110        search_name: str,
1111        search_space: Dict,
1112        metric: str = "f1",
1113        n_trials: int = 20,
1114        best_n: int = 1,
1115        parameters_update: Dict = None,
1116        direction: str = "maximize",
1117        load_episode: str = None,
1118        load_epoch: int = None,
1119        load_strict: bool = True,
1120        prune: bool = False,
1121        force: bool = False,
1122        remove_saved_features: bool = False,
1123    ) -> Dict:
1124        """
1125        Run an optuna hyperparameter search
1126
1127        For a simpler function that fits most use cases, check out `Project.run_default_hyperparameter_search()`.
1128
1129        To use a default search space with this method, import `dlc2action.options.model_hyperparameters`. It is
1130        a dictionary where keys are model names and values are default search spaces.
1131
1132        The task parameters are read from the config files and updated with the parameters_update dictionary.
1133        The model can be either initialized from scratch or loaded from a previously run episode.
1134        For each trial, the objective metric is averaged over a few best epochs.
1135
1136        Parameters
1137        ----------
1138        search_name : str
1139            the name of the search to store it in the meta files and load in run_episode
1140        search_space : dict
1141            a dictionary representing the search space; of this general structure:
1142            {'group/param_name': ('float/int/float_log/int_log', start, end),
1143            'group/param_name': ('categorical', [choices])}, e.g.
1144            {'data/overlap': ('int', 5, 100), 'training/lr': ('float_log', 1e-4, 1e-2),
1145            'data/feature_extraction': ('categorical', ['kinematic', 'bones'])};
1146        metric : str, default f1
1147            the metric to maximize/minimize (see direction)
1148        n_trials : int, default 20
1149            the number of optimization trials to run
1150        best_n : int, default 1
1151            the number of epochs to average the metric; if 0, the last value is taken
1152        parameters_update : dict, optional
1153            the parameters update dictionary
1154        direction : {'maximize', 'minimize'}
1155            optimization direction
1156        load_episode : str, optional
1157            the name of the episode to load the model from
1158        load_epoch : int, optional
1159            the epoch to load the model from (if not provided, the last checkpoint is used)
1160        prune : bool, default False
1161            if `True`, experiments where the optimized metric is improving too slowly will be terminated
1162            (with optuna HyperBand pruner)
1163        force : bool, default False
1164            if `True`, existing searches with the same name will be overwritten
1165        remove_saved_features : bool, default False
1166            if `True`, pre-computed features will be deleted after each run (if the data parameters change)
1167
1168        Returns
1169        -------
1170        dict
1171            a dictionary of best parameters
1172        """
1173
1174        self._check_search_validity(search_name, force=force)
1175        print(f"SEARCH {search_name}")
1176        self.remove_episode(f"_{search_name}")
1177        if parameters_update is None:
1178            parameters_update = {}
1179        parameters_update = self._update(
1180            parameters_update, {"general": {"metric_functions": {metric}}}
1181        )
1182        parameters = self._make_parameters(
1183            f"_{search_name}",
1184            load_episode,
1185            parameters_update,
1186            parameters_update_second={"training": {"model_save_path": None}},
1187            load_epoch=load_epoch,
1188            load_strict=load_strict,
1189        )
1190        task = None
1191
1192        if prune:
1193            pruner = optuna.pruners.HyperbandPruner()
1194        else:
1195            pruner = optuna.pruners.NopPruner()
1196        study = optuna.create_study(direction=direction, pruner=pruner)
1197        runner = _Runner(
1198            search_space=search_space,
1199            load_episode=load_episode,
1200            load_epoch=load_epoch,
1201            metric=metric,
1202            average=best_n,
1203            task=task,
1204            remove_saved_features=remove_saved_features,
1205            project=self,
1206            search_name=search_name,
1207        )
1208        study.optimize(lambda trial: runner.run(trial, parameters), n_trials=n_trials)
1209        search_path = self._search_path(search_name)
1210        os.mkdir(search_path)
1211        fig = optuna.visualization.plot_contour(study)
1212        plotly.offline.plot(
1213            fig, filename=os.path.join(search_path, f"{search_name}_contour.html")
1214        )
1215        fig = optuna.visualization.plot_param_importances(study)
1216        plotly.offline.plot(
1217            fig, filename=os.path.join(search_path, f"{search_name}_importances.html")
1218        )
1219        best_params = study.best_params
1220        best_value = study.best_value
1221        self._save_search(
1222            search_name,
1223            parameters,
1224            n_trials,
1225            best_params,
1226            best_value,
1227            metric,
1228            search_space,
1229        )
1230        self.remove_episode(f"_{search_name}")
1231        runner.clean()
1232        print(f"best parameters: {best_params}")
1233        print("\n")
1234        return best_params
1235
1236    def run_prediction(
1237        self,
1238        prediction_name: str,
1239        episode_names: List,
1240        load_epochs: List = None,
1241        parameters_update: Dict = None,
1242        augment_n: int = 10,
1243        data_path: str = None,
1244        mode: str = "all",
1245        file_paths: Set = None,
1246        remove_saved_features: bool = False,
1247        submission: bool = False,
1248        frame_number_map_file: str = None,
1249        force: bool = False,
1250        embedding: bool = False,
1251    ) -> None:
1252        """
1253        Load models from previously run episodes to generate a prediction
1254
1255        The probabilities predicted by the models are averaged.
1256        Unless `submission` is `True`, the prediction results are saved as a pickled dictionary in the project_name/results/predictions folder
1257        under the {episode_name}_{load_epoch}.pickle name. The file is a nested dictionary where the first-level
1258        keys are the video ids, the second-level keys are the clip ids (like individual names) and the values
1259        are the prediction arrays.
1260
1261        Parameters
1262        ----------
1263        prediction_name : str
1264            the name of the prediction
1265        episode_names : list
1266            a list of string episode names to load the models from
1267        load_epochs : list, optional
1268            a list of integer epoch indices to load the model from; if None, the last ones are used
1269        parameters_update : dict, optional
1270            a dictionary of parameter updates
1271        augment_n : int, default 10
1272            the number of augmentations to average over
1273        data_path : str, optional
1274            the data path to run the prediction for
1275        mode : {'all', 'test', 'val', 'train'}
1276            the subset of the data to make the prediction for (forced to 'all' if data_path is not None)
1277        file_paths : set, optional
1278            a set of string file paths (data with all prefixes + feature files, in any order) to run the prediction
1279            for
1280        remove_saved_features : bool, default False
1281            if `True`, pre-computed features will be deleted
1282        submission : bool, default False
1283            if `True`, a MABe-22 style submission file is generated
1284        frame_number_map_file : str, optional
1285            path to the frame number map file
1286        force : bool, default False
1287            if `True`, existing prediction with this name will be overwritten
1288        """
1289
1290        self._check_prediction_validity(prediction_name, force=force)
1291        print(f"PREDICTION {prediction_name}")
1292        if submission:
1293            task = ...
1294            # TODO: add submission option to _make_prediction
1295            predicted = task.generate_submission(
1296                frame_number_map_file=frame_number_map_file,
1297                dataset=mode,
1298                augment_n=augment_n,
1299            )
1300            folder = os.path.join(
1301                self.project_path,
1302                "results",
1303                "predictions",
1304                f"{prediction_name}",
1305            )
1306            filename = os.path.join(folder, f"{prediction_name}.npy")
1307            np.save(filename, predicted, allow_pickle=True)
1308        else:
1309            try:
1310                (
1311                    task,
1312                    parameters,
1313                    mode,
1314                    prediction,
1315                    inference_time,
1316                ) = self._make_prediction(
1317                    prediction_name,
1318                    episode_names,
1319                    load_epochs,
1320                    parameters_update,
1321                    data_path,
1322                    file_paths,
1323                    mode,
1324                    augment_n,
1325                    evaluate=False,
1326                    embedding=embedding,
1327                )
1328                predicted = task.dataset(mode).generate_full_length_prediction(
1329                    prediction
1330                )
1331            except ValueError:
1332                (
1333                    task,
1334                    parameters,
1335                    mode,
1336                    predicted,
1337                    inference_time,
1338                ) = self._aggregate_predictions(
1339                    prediction_name,
1340                    episode_names,
1341                    load_epochs,
1342                    parameters_update,
1343                    data_path,
1344                    file_paths,
1345                    mode,
1346                    augment_n,
1347                    evaluate=False,
1348                    embedding=embedding,
1349                )
1350            folder = self.prediction_path(prediction_name)
1351            os.mkdir(folder)
1352            for video_id, prediction in predicted.items():
1353                with open(
1354                    os.path.join(
1355                        folder, video_id + f"_{prediction_name}_prediction.pickle"
1356                    ),
1357                    "wb",
1358                ) as f:
1359                    prediction["min_frames"], prediction["max_frames"] = task.dataset(
1360                        mode
1361                    ).get_min_max_frames(video_id)
1362                    behavior_indices = sorted(
1363                        [key for key in task.behaviors_dict() if key != -100]
1364                    )
1365                    prediction["behaviors"] = [
1366                        task.behaviors_dict()[key] for key in behavior_indices
1367                    ]
1368                    pickle.dump(prediction, f)
1369        if remove_saved_features:
1370            self._remove_stores(parameters)
1371        self._save_prediction(
1372            prediction_name,
1373            parameters,
1374            task.behaviors_dict(),
1375            embedding,
1376            inference_time,
1377        )
1378        print("\n")
1379
1380    def evaluate_prediction(
1381        self,
1382        prediction_name: str,
1383        parameters_update: Dict = None,
1384        data_path: str = None,
1385        file_paths: Set = None,
1386        mode: str = None,
1387        remove_saved_features: bool = False,
1388    ) -> Tuple[float, dict]:
1389
1390        with open(
1391            os.path.join(
1392                self.project_path, "results", "predictions", f"{prediction_name}.pickle"
1393            ),
1394            "rb",
1395        ) as f:
1396            prediction = pickle.load(f)
1397        if parameters_update is None:
1398            parameters_update = {}
1399        parameters_update = self._update(
1400            self._predictions().load_parameters(prediction_name), parameters_update
1401        )
1402        parameters_update.pop("model")
1403        task, parameters, mode = self._make_task_prediction(
1404            "_",
1405            load_episode=None,
1406            parameters_update=parameters_update,
1407            data_path=data_path,
1408            file_paths=file_paths,
1409            mode=mode,
1410        )
1411        results = task.evaluate_prediction(prediction, data=mode)
1412        if remove_saved_features:
1413            self._remove_stores(parameters)
1414        print("\n")
1415        return results
1416
1417    def evaluate(
1418        self,
1419        episode_names: List,
1420        load_epochs: List = None,
1421        augment_n: int = 0,
1422        data_path: str = None,
1423        file_paths: Set = None,
1424        mode: str = None,
1425        parameters_update: Dict = None,
1426        multiple_episode_policy: str = "average",
1427        remove_saved_features: bool = False,
1428        skip_updating_meta: bool = True,
1429    ) -> Dict:
1430        """
1431        Load one or several models from previously run episodes to make an evaluation
1432
1433        By default it will run on the test (or validation, if there is no test) subset of the project dataset.
1434
1435        Parameters
1436        ----------
1437        episode_names : list
1438            a list of string episode names to load the models from
1439        load_epochs : list, optional
1440            a list of integer epoch indices to load the model from; if None, the last ones are used
1441        augment_n : int, default 0
1442            the number of augmentations to average over
1443        data_path : str, optional
1444            the data path to run the prediction for
1445        file_paths : set, optional
1446            a set of files to run the prediction for
1447        mode : {'test', 'val', 'train', 'all'}
1448            the subset of the data to make the prediction for (forced to 'all' if data_path is not None;
1449            by default 'test' if test subset is not empty and 'val' otherwise)
1450        parameters_update : dict, optional
1451            a dictionary with parameter updates (cannot change model parameters)
1452        remove_saved_features : bool, default False
1453            if `True`, the dataset will be deleted
1454
1455        Returns
1456        -------
1457        metric : dict
1458            a dictionary of average values of metric functions
1459        """
1460
1461        names = []
1462        for episode_name in episode_names:
1463            names += self._episodes().get_runs(episode_name)
1464        if len(set(episode_names)) == 1:
1465            print(f"EVALUATION {episode_names[0]}")
1466        else:
1467            print(f"EVALUATION {episode_names}")
1468        if len(names) > 1:
1469            evaluate = True
1470        else:
1471            evaluate = False
1472        if multiple_episode_policy == "average":
1473            try:
1474                (
1475                    task,
1476                    parameters,
1477                    mode,
1478                    prediction,
1479                    inference_time,
1480                ) = self._make_prediction(
1481                    "_",
1482                    episode_names,
1483                    load_epochs,
1484                    parameters_update,
1485                    mode=mode,
1486                    data_path=data_path,
1487                    file_paths=file_paths,
1488                    augment_n=augment_n,
1489                    evaluate=evaluate,
1490                )
1491            except:
1492                (
1493                    task,
1494                    parameters,
1495                    mode,
1496                    prediction,
1497                    inference_time,
1498                ) = self._aggregate_predictions(
1499                    "_",
1500                    episode_names,
1501                    load_epochs,
1502                    parameters_update,
1503                    mode=mode,
1504                    data_path=data_path,
1505                    file_paths=file_paths,
1506                    augment_n=augment_n,
1507                    evaluate=evaluate,
1508                )
1509            print("AGGREGATED:")
1510            _, results = task.evaluate_prediction(prediction, data=mode)
1511            if len(names) == 1 and mode == "val" and not skip_updating_meta:
1512                self._update_episode_metrics(names[0], results)
1513        elif multiple_episode_policy == "statistics":
1514            values = defaultdict(lambda: [])
1515            task = None
1516            for name in names:
1517                (
1518                    task,
1519                    parameters,
1520                    mode,
1521                    prediction,
1522                    inference_time,
1523                ) = self._make_prediction(
1524                    "_",
1525                    [name],
1526                    load_epochs,
1527                    parameters_update,
1528                    mode=mode,
1529                    data_path=data_path,
1530                    file_paths=file_paths,
1531                    augment_n=augment_n,
1532                    evaluate=evaluate,
1533                    task=task,
1534                )
1535                _, metrics = task.evaluate_prediction(prediction, data=mode)
1536                for name, value in metrics.items():
1537                    values[name].append(value)
1538                if mode == "val" and not skip_updating_meta:
1539                    self._update_episode_metrics(name, metrics)
1540            results = defaultdict(lambda: {})
1541            mean_string = ""
1542            std_string = ""
1543            for key, value_list in values.items():
1544                results[key]["mean"] = np.mean(value_list)
1545                results[key]["std"] = np.std(value_list)
1546                mean_string += f"{key} {np.mean(value_list):.3f}, "
1547                std_string += f"{key} {np.std(value_list):.3f}, "
1548            print("MEAN:")
1549            print(mean_string)
1550            print("STD:")
1551            print(std_string)
1552        else:
1553            raise ValueError(
1554                f"The {multiple_episode_policy} multiple episode policy is not recognized; please choose "
1555                f"from ['average', 'statistics']"
1556            )
1557        if len(names) > 0 and remove_saved_features:
1558            self._remove_stores(parameters)
1559        print(f"Inference time: {inference_time}")
1560        print("\n")
1561        return results
1562
1563    def _generate_similarity_score(
1564        self,
1565        prediction_name: str,
1566        target_video_id: str,
1567        target_clip: str,
1568        target_start: int,
1569        target_end: int,
1570    ) -> Dict:
1571        with open(
1572            os.path.join(
1573                self.project_path,
1574                "results",
1575                "predictions",
1576                f"{prediction_name}.pickle",
1577            ),
1578            "rb",
1579        ) as f:
1580            prediction = pickle.load(f)
1581        target = prediction[target_video_id][target_clip][:, target_start:target_end]
1582        score_dict = defaultdict(lambda: {})
1583        for video_id in prediction:
1584            for clip_id in prediction[video_id]:
1585                score_dict[video_id][clip_id] = torch.cdist(
1586                    target.T, prediction[video_id][score_dict].T
1587                ).min(0)
1588        return score_dict
1589
1590    def _suggest_intervals_from_dict(self, score_dict, min_length, n_intervals) -> Dict:
1591        interval_address = {}
1592        interval_value = {}
1593        s = 0
1594        n = 0
1595        for video_id, video_dict in score_dict.items():
1596            for clip_id, value in video_dict.items():
1597                s += value.mean()
1598                n += 1
1599        mean_value = s / n
1600        alpha = 1.75
1601        for it in range(10):
1602            id = 0
1603            interval_address = {}
1604            interval_value = {}
1605            for video_id, video_dict in score_dict.items():
1606                for clip_id, value in video_dict.items():
1607                    res_indices_start, res_indices_end = apply_threshold(
1608                        value,
1609                        threshold=(2 - alpha * (0.9**it)) * mean_value,
1610                        low=True,
1611                        error_mask=None,
1612                        min_frames=min_length,
1613                        smooth_interval=0,
1614                    )
1615                    for start, end in zip(res_indices_start, res_indices_end):
1616                        interval_address[id] = [video_id, clip_id, start, end]
1617                        interval_value[id] = score_dict[video_id][clip_id][
1618                            start:end
1619                        ].mean()
1620                        id += 1
1621            if len(interval_address) >= n_intervals:
1622                break
1623        if len(interval_address) < n_intervals:
1624            warnings.warn(
1625                f"Could not get {n_intervals} intervals from the data, saving the result with {len(interval_address)} intervals"
1626            )
1627        sorted_intervals = sorted(
1628            interval_value.items(), key=lambda x: x[1], reverse=True
1629        )
1630        output_intervals = [
1631            interval_address[x[0]]
1632            for x in sorted_intervals[: min(len(sorted_intervals), n_intervals)]
1633        ]
1634        output = defaultdict(lambda: [])
1635        for video_id, clip_id, start, end in output_intervals:
1636            output[video_id].append([start, end, clip_id])
1637        return output
1638
1639    def list_episodes(
1640        self,
1641        episode_names: List = None,
1642        value_filter: str = "",
1643        display_parameters: List = None,
1644        print_results: bool = True,
1645    ) -> pd.DataFrame:
1646        """
1647        Get a filtered pandas dataframe with episode metadata
1648
1649        Parameters
1650        ----------
1651        episode_names : list
1652            a list of strings of episode names
1653        value_filter : str
1654            a string of filters to apply; of this general structure:
1655            'group_name1/par_name1::(</>/<=/>=/=)value1,group_name2/par_name2::(</>/<=/>=/=)value2', e.g.
1656            'data/overlap::=50,results/recall::>0.5,data/feature_extraction::=kinematic,meta/training_time::>=00:00:10'
1657        display_parameters : list
1658            list of parameters to display (e.g. ['data/overlap', 'results/recall'])
1659        print_results : bool, default True
1660            if True, the result will be printed to standard output
1661
1662        Returns
1663        -------
1664        pd.DataFrame
1665            the filtered dataframe
1666        """
1667
1668        episodes = self._episodes().list_episodes(
1669            episode_names, value_filter, display_parameters
1670        )
1671        if print_results:
1672            print("TRAINING EPISODES")
1673            print(episodes)
1674            print("\n")
1675        return episodes
1676
1677    def list_predictions(
1678        self,
1679        episode_names: List = None,
1680        value_filter: str = "",
1681        display_parameters: List = None,
1682        print_results: bool = True,
1683    ) -> pd.DataFrame:
1684        """
1685        Get a filtered pandas dataframe with prediction metadata
1686
1687        Parameters
1688        ----------
1689        episode_names : list
1690            a list of strings of episode names
1691        value_filter : str
1692            a string of filters to apply; of this general structure:
1693            'group_name1/par_name1:(<>=)value1,group_name2/par_name2:(<>=)value2', e.g.
1694            'data/overlap:=50,results/recall:>0.5,data/feature_extraction:=kinematic'
1695        display_parameters : list
1696            list of parameters to display (e.g. ['data/overlap', 'results/recall'])
1697        print_results : bool, default True
1698            if True, the result will be printed to standard output
1699
1700        Returns
1701        -------
1702        pd.DataFrame
1703            the filtered dataframe
1704        """
1705
1706        predictions = self._predictions().list_episodes(
1707            episode_names, value_filter, display_parameters
1708        )
1709        if print_results:
1710            print("PREDICTIONS")
1711            print(predictions)
1712            print("\n")
1713        return predictions
1714
1715    def list_searches(
1716        self,
1717        search_names: List = None,
1718        value_filter: str = "",
1719        display_parameters: List = None,
1720        print_results: bool = True,
1721    ) -> pd.DataFrame:
1722        """
1723        Get a filtered pandas dataframe with hyperparameter search metadata
1724
1725        Parameters
1726        ----------
1727        search_names : list
1728            a list of strings of search names
1729        value_filter : str
1730            a string of filters to apply; of this general structure:
1731            'group_name1/par_name1:(<>=)value1,group_name2/par_name2:(<>=)value2', e.g.
1732            'data/overlap:=50,results/recall:>0.5,data/feature_extraction:=kinematic'
1733        display_parameters : list
1734            list of parameters to display (e.g. ['data/overlap', 'results/recall'])
1735        print_results : bool, default True
1736            if True, the result will be printed to standard output
1737
1738        Returns
1739        -------
1740        pd.DataFrame
1741            the filtered dataframe
1742        """
1743
1744        searches = self._searches().list_episodes(
1745            search_names, value_filter, display_parameters
1746        )
1747        if print_results:
1748            print("SEARCHES")
1749            print(searches)
1750            print("\n")
1751        return searches
1752
1753    def get_best_parameters(
1754        self,
1755        search_name: str,
1756        round_to_binary: List = None,
1757    ):
1758        params, model = self._searches().get_best_params(
1759            search_name, round_to_binary=round_to_binary
1760        )
1761        params = self._update(params, {"general": {"model_name": model}})
1762        return params
1763
1764    def list_best_parameters(
1765        self, search_name: str, print_results: bool = True
1766    ) -> Dict:
1767        """
1768        Get the raw dictionary of best parameters found by a search
1769
1770        Parameters
1771        ----------
1772        search_name : str
1773            the name of the search
1774        print_results : bool, default True
1775            if True, the result will be printed to standard output
1776
1777        Returns
1778        -------
1779        best_params : dict
1780            a dictionary of the best parameters where the keys are in '{group}/{name}' format
1781        """
1782
1783        params = self._searches().get_best_params_raw(search_name)
1784        if print_results:
1785            print(f"SEARCH RESULTS {search_name}")
1786            for k, v in params.items():
1787                print(f"{k}: {v}")
1788            print("\n")
1789        return params
1790
1791    def plot_episodes(
1792        self,
1793        episode_names: List,
1794        metrics: List,
1795        modes: List = None,
1796        title: str = None,
1797        episode_labels: List = None,
1798        save_path: str = None,
1799        add_hlines: List = None,
1800        epoch_limits: List = None,
1801        colors: List = None,
1802        add_highpoint_hlines: bool = False,
1803    ) -> None:
1804        """
1805        Plot episode training curves
1806
1807        Parameters
1808        ----------
1809        episode_names : list
1810            a list of episode names to plot; to plot to episodes in one line combine them in a list
1811            (e.g. ['episode1', ['episode2', 'episode3']] to plot episode2 and episode3 as one experiment)
1812        metrics : list
1813            a list of metric to plot
1814        modes : list, optional
1815            a list of modes to plot ('train' and/or 'val'; `['val']` by default)
1816        title : str, optional
1817            title for the plot
1818        episode_labels : list, optional
1819            a list of strings used to label the curves (has to be the same length as episode_names)
1820        save_path : str, optional
1821            the path to save the resulting plot
1822        add_hlines : list, optional
1823            a list of float values (or (value, label) tuples) to mark with horizontal lines
1824        colors: list, optional
1825            a list of matplotlib colors
1826        add_highpoint_hlines : bool, default False
1827            if `True`, horizontal lines will be added at the highest value of each episode
1828        """
1829
1830        if modes is None:
1831            modes = ["val"]
1832        if add_hlines is None:
1833            add_hlines = []
1834        logs = []
1835        epochs = []
1836        labels = []
1837        if episode_labels is not None:
1838            assert len(episode_labels) == len(episode_names)
1839        for name_i, name in enumerate(episode_names):
1840            log_params = product(metrics, modes)
1841            for metric, mode in log_params:
1842                if episode_labels is not None:
1843                    label = episode_labels[name_i]
1844                else:
1845                    label = deepcopy(name)
1846                if len(modes) != 1:
1847                    label += f"_{mode}"
1848                if len(metrics) != 1:
1849                    label += f"_{metric}"
1850                labels.append(label)
1851                if isinstance(name, Iterable) and not isinstance(name, str):
1852                    epoch_list = defaultdict(lambda: [])
1853                    multi_logs = defaultdict(lambda: [])
1854                    for i, n in enumerate(name):
1855                        runs = self._episodes().get_runs(n)
1856                        if len(runs) > 1:
1857                            for run in runs:
1858                                index = run.split("::")[-1]
1859                                if multi_logs[index] == []:
1860                                    if multi_logs["null"] is None:
1861                                        raise RuntimeError(
1862                                            "The run indices are not consistent across episodes!"
1863                                        )
1864                                    else:
1865                                        multi_logs[index] += multi_logs["null"]
1866                                multi_logs[index] += list(
1867                                    self._episode(run).get_metric_log(mode, metric)
1868                                )
1869                                start = (
1870                                    0
1871                                    if len(epoch_list[index]) == 0
1872                                    else epoch_list[index][-1]
1873                                )
1874                                epoch_list[index] += [
1875                                    x + start
1876                                    for x in self._episode(run).get_epoch_list(mode)
1877                                ]
1878                            multi_logs["null"] = None
1879                        else:
1880                            if len(multi_logs.keys()) > 1:
1881                                raise RuntimeError(
1882                                    "Cannot plot a single-run episode after a multi-run episode!"
1883                                )
1884                            multi_logs["null"] += list(
1885                                self._episode(n).get_metric_log(mode, metric)
1886                            )
1887                            start = (
1888                                0
1889                                if len(epoch_list["null"]) == 0
1890                                else epoch_list["null"][-1]
1891                            )
1892                            epoch_list["null"] += [
1893                                x + start for x in self._episode(n).get_epoch_list(mode)
1894                            ]
1895                    if len(multi_logs.keys()) == 1:
1896                        log = multi_logs["null"]
1897                        epochs.append(epoch_list["null"])
1898                    else:
1899                        log = tuple([v for k, v in multi_logs.items() if k != "null"])
1900                        epochs.append(
1901                            tuple([v for k, v in epoch_list.items() if k != "null"])
1902                        )
1903                else:
1904                    runs = self._episodes().get_runs(name)
1905                    if len(runs) > 1:
1906                        log = []
1907                        for run in runs:
1908                            tracked_metrics = self._episode(run).get_metrics()
1909                            if metric in tracked_metrics:
1910                                log.append(
1911                                    list(
1912                                        self._episode(run).get_metric_log(mode, metric)
1913                                    )
1914                                )
1915                            else:
1916                                relevant = []
1917                                for m in tracked_metrics:
1918                                    m_split = m.split("_")
1919                                    if (
1920                                        "_".join(m_split[:-1]) == metric
1921                                        and m_split[-1].isnumeric()
1922                                    ):
1923                                        relevant.append(m)
1924                                if len(relevant) == 0:
1925                                    raise ValueError(
1926                                        f"The {metric} metric was not tracked at {run}"
1927                                    )
1928                                arr = 0
1929                                for m in relevant:
1930                                    arr += self._episode(run).get_metric_log(mode, m)
1931                                arr /= len(relevant)
1932                                log.append(list(arr))
1933                        log = tuple(log)
1934                        epochs.append(
1935                            tuple(
1936                                [
1937                                    self._episode(run).get_epoch_list(mode)
1938                                    for run in runs
1939                                ]
1940                            )
1941                        )
1942                    else:
1943                        tracked_metrics = self._episode(name).get_metrics()
1944                        if metric in tracked_metrics:
1945                            log = list(self._episode(name).get_metric_log(mode, metric))
1946                        else:
1947                            relevant = []
1948                            for m in tracked_metrics:
1949                                m_split = m.split("_")
1950                                if (
1951                                    "_".join(m_split[:-1]) == metric
1952                                    and m_split[-1].isnumeric()
1953                                ):
1954                                    relevant.append(m)
1955                            if len(relevant) == 0:
1956                                raise ValueError(
1957                                    f"The {metric} metric was not tracked at {name}"
1958                                )
1959                            arr = 0
1960                            for m in relevant:
1961                                arr += self._episode(name).get_metric_log(mode, m)
1962                            arr /= len(relevant)
1963                            log = list(arr)
1964                        epochs.append(self._episode(name).get_epoch_list(mode))
1965                logs.append(log)
1966        # if episode_labels is not None:
1967        #     print(f'{len(episode_labels)=}, {len(logs)=}')
1968        #     if len(episode_labels) != len(logs):
1969
1970        #         raise ValueError(
1971        #             f"The length of episode_labels ({len(episode_labels)}) has to be equal to the length of "
1972        #             f"curves ({len(logs)})!"
1973        #         )
1974        #     else:
1975        #         labels = episode_labels
1976        if colors is None:
1977            colors = cm.rainbow(np.linspace(0, 1, len(logs)))
1978        if len(colors) != len(logs):
1979            raise ValueError(
1980                "The length of colors has to be equal to the length of curves (metrics * modes * episode_names)!"
1981            )
1982        plt.figure()
1983        length = 0
1984        for log, label, color, epoch_list in zip(logs, labels, colors, epochs):
1985            if type(log) is list:
1986                if len(log) > length:
1987                    length = len(log)
1988                plt.plot(
1989                    epoch_list,
1990                    log,
1991                    label=label,
1992                    color=color,
1993                )
1994                if add_highpoint_hlines:
1995                    plt.axhline(np.max(log), linestyle="dashed", color=color)
1996            else:
1997                for l, xx in zip(log, epoch_list):
1998                    if len(l) > length:
1999                        length = len(l)
2000                    plt.plot(
2001                        xx,
2002                        l,
2003                        color=color,
2004                        alpha=0.2,
2005                    )
2006                if not all([len(x) == len(log[0]) for x in log]):
2007                    warnings.warn(
2008                        f"Got logs with unequal lengths in parallel runs for {label}"
2009                    )
2010                    log = list(log)
2011                    epoch_list = list(epoch_list)
2012                    for i, x in enumerate(epoch_list):
2013                        to_remove = []
2014                        for j, y in enumerate(x[1:]):
2015                            if y <= x[j - 1]:
2016                                y_ind = x.index(y)
2017                                to_remove += list(range(y_ind, j))
2018                        epoch_list[i] = [
2019                            y for j, y in enumerate(x) if j not in to_remove
2020                        ]
2021                        log[i] = [y for j, y in enumerate(log[i]) if j not in to_remove]
2022                    length = min([len(x) for x in log])
2023                    for i in range(len(log)):
2024                        log[i] = log[i][:length]
2025                        epoch_list[i] = epoch_list[i][:length]
2026                    if not all([x == epoch_list[0] for x in epoch_list]):
2027                        raise RuntimeError(
2028                            f"Got different epoch indices in parallel runs for {label}"
2029                        )
2030                mean = np.array(log).mean(0)
2031                plt.plot(
2032                    epoch_list[0],
2033                    mean,
2034                    label=label,
2035                    color=color,
2036                )
2037                if add_highpoint_hlines:
2038                    plt.axhline(np.max(mean), linestyle="dashed", color=color)
2039        for x in add_hlines:
2040            label = None
2041            if isinstance(x, Iterable):
2042                x, label = x
2043            plt.axhline(x, label=label)
2044            plt.xlim((0, length))
2045
2046        plt.legend()
2047        plt.xlabel("epochs")
2048        if len(metrics) == 1:
2049            plt.ylabel(metrics[0])
2050        else:
2051            plt.ylabel("value")
2052        if title is None:
2053            if len(episode_names) == 1:
2054                title = episode_names[0]
2055            elif len(metrics) == 1:
2056                title = metrics[0]
2057        if epoch_limits is not None:
2058            plt.xlim(epoch_limits)
2059        if title is not None:
2060            plt.title(title)
2061        plt.show()
2062        if save_path is not None:
2063            plt.savefig(save_path)
2064
2065    def update_parameters(
2066        self,
2067        parameters_update: Dict = None,
2068        load_search: str = None,
2069        load_parameters: List = None,
2070        round_to_binary: List = None,
2071    ) -> None:
2072        """
2073        Update the parameters in the project config files
2074
2075        Parameters
2076        ----------
2077        parameters_update : dict, optional
2078            a dictionary of parameter updates
2079        load_search : str, optional
2080            the name of hyperparameter search results to load to config
2081        load_parameters : list, optional
2082            a list of lists of string names of the parameters to load from the searches
2083        round_to_binary : list, optional
2084            a list of string names of the loaded parameters that should be rounded to the nearest power of two
2085        """
2086
2087        keys = [
2088            "general",
2089            "losses",
2090            "metrics",
2091            "ssl",
2092            "training",
2093            "data",
2094        ]
2095        parameters = self._read_parameters(catch_blanks=False)
2096        if parameters_update is not None:
2097            if "model" in parameters_update:
2098                model_params = parameters_update.pop("model")
2099            else:
2100                model_params = None
2101            if "features" in parameters_update:
2102                feat_params = parameters_update.pop("features")
2103            else:
2104                feat_params = None
2105            if "augmentations" in parameters_update:
2106                aug_params = parameters_update.pop("augmentations")
2107            else:
2108                aug_params = None
2109            parameters = self._update(parameters, parameters_update)
2110            model_name = parameters["general"]["model_name"]
2111            parameters["model"] = self._open_yaml(
2112                os.path.join(self.project_path, "config", "model", f"{model_name}.yaml")
2113            )
2114            if model_params is not None:
2115                parameters["model"] = self._update(parameters["model"], model_params)
2116            feat_name = parameters["general"]["feature_extraction"]
2117            parameters["features"] = self._open_yaml(
2118                os.path.join(
2119                    self.project_path, "config", "features", f"{feat_name}.yaml"
2120                )
2121            )
2122            if feat_params is not None:
2123                parameters["features"] = self._update(
2124                    parameters["features"], feat_params
2125                )
2126            aug_name = options.extractor_to_transformer[
2127                parameters["general"]["feature_extraction"]
2128            ]
2129            parameters["augmentations"] = self._open_yaml(
2130                os.path.join(
2131                    self.project_path, "config", "augmentations", f"{aug_name}.yaml"
2132                )
2133            )
2134            if aug_params is not None:
2135                parameters["augmentations"] = self._update(
2136                    parameters["augmentations"], aug_params
2137                )
2138        if load_search is not None:
2139            parameters_update, model_name = self._searches().get_best_params(
2140                load_search, load_parameters, round_to_binary
2141            )
2142            parameters["general"]["model_name"] = model_name
2143            parameters["model"] = self._open_yaml(
2144                os.path.join(self.project_path, "config", "model", f"{model_name}.yaml")
2145            )
2146            parameters = self._update(parameters, parameters_update)
2147        for key in keys:
2148            with open(
2149                os.path.join(self.project_path, "config", f"{key}.yaml"), "w", encoding="utf-8"
2150            ) as f:
2151                YAML().dump(parameters[key], f)
2152        model_name = parameters["general"]["model_name"]
2153        model_path = os.path.join(
2154            self.project_path, "config", "model", f"{model_name}.yaml"
2155        )
2156        with open(model_path, "w", encoding="utf-8") as f:
2157            YAML().dump(parameters["model"], f)
2158        features_name = parameters["general"]["feature_extraction"]
2159        features_path = os.path.join(
2160            self.project_path, "config", "features", f"{features_name}.yaml"
2161        )
2162        with open(features_path, "w", encoding="utf-8") as f:
2163            YAML().dump(parameters["features"], f)
2164        aug_name = options.extractor_to_transformer[features_name]
2165        aug_path = os.path.join(
2166            self.project_path, "config", "augmentations", f"{aug_name}.yaml"
2167        )
2168        with open(aug_path, "w", encoding="utf-8") as f:
2169            YAML().dump(parameters["augmentations"], f)
2170
2171    def get_summary(
2172        self,
2173        episode_names: list,
2174        method: str = "last",
2175        average: int = 1,
2176        metrics: List = None,
2177    ) -> Dict:
2178        """
2179        Get a summary of episode statistics
2180
2181        If the episode has multiple runs, the statistics will be aggregated over all of them.
2182
2183        Parameters
2184        ----------
2185        episode_name : str
2186            the name of the episode
2187        method : ["best", "last"]
2188            the method for choosing the epochs
2189        average : int, default 1
2190            the number of epochs to average over (for each run)
2191        metrics : list, optional
2192            a list of metrics
2193
2194        Returns
2195        -------
2196        statistics : dict
2197            a nested dictionary where first-level keys are metric names and second-level keys are 'mean' for the mean
2198            and 'std' for the standard deviation
2199        """
2200
2201        runs = []
2202        for episode_name in episode_names:
2203            runs_ep = self._episodes().get_runs(episode_name)
2204            if len(runs_ep) == 0:
2205                raise RuntimeError(
2206                    f"There is no {episode_name} episode in the project memory"
2207                )
2208            runs += runs_ep
2209        if metrics is None:
2210            metrics = self._episode(runs[0]).get_metrics()
2211
2212        values = {m: [] for m in metrics}
2213        for run in runs:
2214            for m in metrics:
2215                log = self._episode(run).get_metric_log(mode="val", metric_name=m)
2216                if method == "best":
2217                    log = sorted(log)
2218                    values[m] += list(log[-average:])
2219                elif method == "last":
2220                    if len(log) == 0:
2221                        episodes = self._episodes().data
2222                        if average == 1 and ("results", m) in episodes.columns:
2223                            values[m] += [episodes.loc[run, ("results", m)]]
2224                        else:
2225                            raise RuntimeError(f"Did not find {m} metric for {run} run")
2226                    values[m] += list(log[-average:])
2227                elif method.startswith("epoch"):
2228                    epoch = int(method[5:]) - 1
2229                    pars = self._episodes().load_parameters(run)
2230                    step = int(pars["training"]["validation_interval"])
2231                    values[m] += [log[epoch // step]]
2232                else:
2233                    raise ValueError(
2234                        f"The {method} method is not recognized! Please choose from ['last', 'best', 'epoch...']"
2235                    )
2236        statistics = defaultdict(lambda: {})
2237        for m, v in values.items():
2238            statistics[m]["mean"] = np.mean(v)
2239            statistics[m]["std"] = np.std(v)
2240        print(f"SUMMARY {episode_names}")
2241        for m, v in statistics.items():
2242            print(f'{m}: mean {v["mean"]:.3f}, std {v["std"]:.3f}')
2243        print("\n")
2244        return dict(statistics)
2245
2246    @staticmethod
2247    def remove_project(name: str, projects_path: str = None) -> None:
2248        """
2249        Remove all project files and experiment records and results
2250        """
2251
2252        if projects_path is None:
2253            projects_path = os.path.join(str(Path.home()), "DLC2Action")
2254        project_path = os.path.join(projects_path, name)
2255        if os.path.exists(project_path):
2256            shutil.rmtree(project_path)
2257
2258    def remove_saved_features(
2259        self,
2260        dataset_names: List = None,
2261        exceptions: List = None,
2262        remove_active: bool = False,
2263    ) -> None:
2264        """
2265        Remove saved pre-computed dataset files
2266
2267        By default, all pre-computed features will be deleted.
2268        No essential information can get lost, storing them only saves time. Be careful with deleting datasets
2269        while training or inference is happening though.
2270
2271        Parameters
2272        ----------
2273        dataset_names : list, optional
2274            a list of dataset names to delete (by default all names are added)
2275        exceptions : list, optional
2276            a list of dataset names to not be deleted
2277        remove_active : bool, default False
2278            if `False`, datasets used by unfinished episodes will not be deleted
2279        """
2280
2281        print("Removing datasets...")
2282        if dataset_names is None:
2283            dataset_names = []
2284        if exceptions is None:
2285            exceptions = []
2286        if not remove_active:
2287            exceptions += self._episodes().get_active_datasets()
2288        dataset_path = os.path.join(self.project_path, "saved_datasets")
2289        if os.path.exists(dataset_path):
2290            if dataset_names == []:
2291                dataset_names = set([f.split(".")[0] for f in os.listdir(dataset_path)])
2292
2293            to_remove = [
2294                x
2295                for x in dataset_names
2296                if os.path.exists(os.path.join(dataset_path, x)) and x not in exceptions
2297            ]
2298            if len(to_remove) > 2:
2299                to_remove = tqdm(to_remove)
2300            for dataset in to_remove:
2301                shutil.rmtree(os.path.join(dataset_path, dataset))
2302            to_remove = [
2303                f"{x}.pickle"
2304                for x in dataset_names
2305                if os.path.exists(os.path.join(dataset_path, f"{x}.pickle"))
2306                and x not in exceptions
2307            ]
2308            for dataset in to_remove:
2309                os.remove(os.path.join(dataset_path, dataset))
2310            names = self._saved_datasets().dataset_names()
2311            self._saved_datasets().remove(names)
2312        print("\n")
2313
2314    def remove_extra_checkpoints(
2315        self, episode_names: List = None, exceptions: List = None
2316    ) -> None:
2317        """
2318        Remove intermediate model checkpoint files (only leave the results of the last epoch)
2319
2320        By default, all intermediate checkpoints will be deleted.
2321        Files in the model folder that are not associated with any record in the meta files are also deleted.
2322
2323        Parameters
2324        ----------
2325        episode_names : list, optional
2326            a list of episode names to clean (by default all names are added)
2327        exceptions : list, optional
2328            a list of episode names to not clean
2329        """
2330
2331        model_path = os.path.join(self.project_path, "results", "model")
2332        try:
2333            all_names = self._episodes().data.index
2334        except:
2335            all_names = os.listdir(model_path)
2336        if episode_names is None:
2337            episode_names = all_names
2338        if exceptions is None:
2339            exceptions = []
2340        to_remove = [x for x in episode_names if x not in exceptions]
2341        folders = os.listdir(model_path)
2342        for folder in folders:
2343            if folder not in all_names:
2344                shutil.rmtree(os.path.join(model_path, folder))
2345            elif folder in to_remove:
2346                files = os.listdir(os.path.join(model_path, folder))
2347                for file in sorted(files)[:-1]:
2348                    os.remove(os.path.join(model_path, folder, file))
2349
2350    def remove_search(self, search_name: str) -> None:
2351        """
2352        Remove a hyperparameter search record
2353
2354        Parameters
2355        ----------
2356        search_name : str
2357            the name of the search to remove
2358        """
2359
2360        self._searches().remove_episode(search_name)
2361        graph_path = os.path.join(self.project_path, "results", "searches", search_name)
2362        if os.path.exists(graph_path):
2363            shutil.rmtree(graph_path)
2364
2365    def remove_prediction(self, prediction_name: str) -> None:
2366        """
2367        Remove a prediction record
2368
2369        Parameters
2370        ----------
2371        prediction_name : str
2372            the name of the prediction to remove
2373        """
2374
2375        self._predictions().remove_episode(prediction_name)
2376        prediction_path = os.path.join(
2377            self.project_path, "results", "predictions", prediction_name
2378        )
2379        if os.path.exists(prediction_path):
2380            shutil.rmtree(prediction_path)
2381
2382    def remove_episode(self, episode_name: str) -> None:
2383        """
2384        Remove all model, logs and metafile records related to an episode
2385
2386        Parameters
2387        ----------
2388        episode_name : str
2389            the name of the episode to remove
2390        """
2391
2392        runs = self._episodes().get_runs(episode_name)
2393        runs.append(episode_name)
2394        for run in runs:
2395            self._episodes().remove_episode(run)
2396            model_path = os.path.join(self.project_path, "results", "model", run)
2397            if os.path.exists(model_path):
2398                shutil.rmtree(model_path)
2399            log_path = os.path.join(self.project_path, "results", "logs", f"{run}.txt")
2400            if os.path.exists(log_path):
2401                os.remove(log_path)
2402
2403    def prune_unfinished(self, exceptions: List = None) -> None:
2404        """
2405        Remove all interrupted episodes
2406
2407        Remove all episodes that either don't have a log file or have less epochs in the log file than in
2408        the training parameters or have a model folder but not a record. Note that it can remove episodes that are
2409        currently running!
2410
2411        Parameters
2412        ----------
2413        exceptions : list
2414            the episodes to keep even if they are interrupted
2415
2416        Returns
2417        -------
2418        pruned : list
2419            a list of the episode names that were pruned
2420        """
2421
2422        if exceptions is None:
2423            exceptions = []
2424        unfinished = self._episodes().unfinished_episodes()
2425        unfinished = [x for x in unfinished if x not in exceptions]
2426        model_folders = os.listdir(os.path.join(self.project_path, "results", "model"))
2427        unfinished += [
2428            x for x in model_folders if x not in self._episodes().list_episodes().index
2429        ]
2430        print(f"PRUNING {unfinished}")
2431        for episode_name in unfinished:
2432            self.remove_episode(episode_name)
2433        print(f"\n")
2434        return unfinished
2435
2436    def prediction_path(self, prediction_name: str) -> str:
2437        """
2438        Get the path where prediction files are saved
2439
2440        Parameters
2441        ----------
2442        prediction_name : str
2443            name of the prediction
2444
2445        Returns
2446        -------
2447        prediction_path : str
2448            the file path
2449        """
2450
2451        return os.path.join(
2452            self.project_path, "results", "predictions", f"{prediction_name}"
2453        )
2454
2455    @classmethod
2456    def print_data_types(cls):
2457        print("DATA TYPES:")
2458        for key, value in cls.data_types().items():
2459            print(f"{key}:")
2460            print(value.__doc__)
2461
2462    @classmethod
2463    def print_annotation_types(cls):
2464        print("ANNOTATION TYPES:")
2465        for key, value in cls.annotation_types().items():
2466            print(f"{key}:")
2467            print(value.__doc__)
2468
2469    @staticmethod
2470    def data_types() -> List:
2471        """
2472        Get available data types
2473
2474        Returns
2475        -------
2476        list
2477            available data types
2478        """
2479
2480        return options.input_stores
2481
2482    @staticmethod
2483    def annotation_types() -> List:
2484        """
2485        Get available annotation types
2486
2487        Returns
2488        -------
2489        list
2490            available annotation types
2491        """
2492
2493        return options.annotation_stores
2494
2495    def _save_mask(self, file: Dict, mask_name: str):
2496        """
2497        Save a mask file
2498        """
2499
2500        if not os.path.exists(self._mask_path()):
2501            os.mkdir(self._mask_path())
2502        with open(os.path.join(self._mask_path(), mask_name + ".pickle"), "wb") as f:
2503            pickle.dump(file, f)
2504
2505    def _load_mask(self, mask_name: str) -> Dict:
2506        """
2507        Load a mask file
2508        """
2509
2510        with open(os.path.join(self._mask_path(), mask_name + ".pickle"), "rb") as f:
2511            data = pickle.load(f)
2512        return data
2513
2514    def _thresholds(self) -> DecisionThresholds:
2515        """
2516        Get the decision thresholds meta object
2517        """
2518
2519        return DecisionThresholds(self._thresholds_path())
2520
2521    def _episodes(self) -> SavedRuns:
2522        """
2523        Get the episodes meta object
2524
2525        Returns
2526        -------
2527        episodes : SavedRuns
2528            the episodes meta object
2529        """
2530
2531        try:
2532            return SavedRuns(self._episodes_path(), self.project_path)
2533        except:
2534            self.load_metadata_backup()
2535            return SavedRuns(self._episodes_path(), self.project_path)
2536
2537    def _predictions(self) -> SavedRuns:
2538        """
2539        Get the predictions meta object
2540
2541        Returns
2542        -------
2543        predictions : SavedRuns
2544            the predictions meta object
2545        """
2546
2547        try:
2548            return SavedRuns(self._predictions_path(), self.project_path)
2549        except:
2550            self.load_metadata_backup()
2551            return SavedRuns(self._predictions_path(), self.project_path)
2552
2553    def _saved_datasets(self) -> SavedStores:
2554        """
2555        Get the datasets meta object
2556
2557        Returns
2558        -------
2559        datasets : SavedStores
2560            the datasets meta object
2561        """
2562
2563        try:
2564            return SavedStores(self._saved_datasets_path())
2565        except:
2566            self.load_metadata_backup()
2567            return SavedStores(self._saved_datasets_path())
2568
2569    def _prediction(self, name: str) -> Run:
2570        """
2571        Get a prediction meta object
2572
2573        Parameters
2574        ----------
2575        name : str
2576            episode name
2577
2578        Returns
2579        -------
2580        prediction : Run
2581            the prediction meta object
2582        """
2583
2584        try:
2585            return Run(name, self.project_path, meta_path=self._predictions_path())
2586        except:
2587            self.load_metadata_backup()
2588            return Run(name, self.project_path, meta_path=self._predictions_path())
2589
2590    def _episode(self, name: str) -> Run:
2591        """
2592        Get an episode meta object
2593
2594        Parameters
2595        ----------
2596        name : str
2597            episode name
2598
2599        Returns
2600        -------
2601        episode : Run
2602            the episode meta object
2603        """
2604
2605        try:
2606            return Run(name, self.project_path, meta_path=self._episodes_path())
2607        except:
2608            self.load_metadata_backup()
2609            return Run(name, self.project_path, meta_path=self._episodes_path())
2610
2611    def _searches(self) -> Searches:
2612        """
2613        Get the hyperparameter search meta object
2614
2615        Returns
2616        -------
2617        searches : Searches
2618            the searches meta object
2619        """
2620
2621        try:
2622            return Searches(self._searches_path(), self.project_path)
2623        except:
2624            self.load_metadata_backup()
2625            return Searches(self._searches_path(), self.project_path)
2626
2627    def _update_configs(self) -> None:
2628        """
2629        Update the project config files with newly added files and parameters
2630        """
2631
2632        self.update_parameters({"data": {"data_path": self.data_path}})
2633        folders = ["augmentations", "features", "model"]
2634        original_path = os.path.join(
2635            os.path.dirname(os.path.dirname(__file__)), "config"
2636        )
2637        project_path = os.path.join(self.project_path, "config")
2638        filenames = [x for x in os.listdir(original_path) if x.endswith("yaml")]
2639        for folder in folders:
2640            filenames += [
2641                os.path.join(folder, x)
2642                for x in os.listdir(os.path.join(original_path, folder))
2643            ]
2644        filenames.append(os.path.join("data", f"{self.data_type}.yaml"))
2645        if self.annotation_type != "none":
2646            filenames.append(os.path.join("annotation", f"{self.annotation_type}.yaml"))
2647        for file in filenames:
2648            filepath_original = os.path.join(original_path, file)
2649            if file.startswith("data") or file.startswith("annotation"):
2650                file = os.path.basename(file)
2651            filepath_project = os.path.join(project_path, file)
2652            if not os.path.exists(filepath_project):
2653                shutil.copy(filepath_original, filepath_project)
2654            else:
2655                original_pars = self._open_yaml(filepath_original)
2656                project_pars = self._open_yaml(filepath_project)
2657                to_remove = []
2658                for key, value in project_pars.items():
2659                    if key not in original_pars:
2660                        if key not in ["data_type", "annotation_type"]:
2661                            to_remove.append(key)
2662                for key in to_remove:
2663                    project_pars.pop(key)
2664                to_remove = []
2665                for key, value in original_pars.items():
2666                    if key in project_pars:
2667                        to_remove.append(key)
2668                for key in to_remove:
2669                    original_pars.pop(key)
2670                project_pars = self._update(project_pars, original_pars)
2671                with open(filepath_project, "w", encoding="utf-8") as f:
2672                    YAML().dump(project_pars, f)
2673
2674    def _update_project(self) -> None:
2675        """
2676        Update project files with the current version
2677        """
2678
2679        version_file = self._version_path()
2680        ok = True
2681        if not os.path.exists(version_file):
2682            ok = False
2683        else:
2684            with open(version_file) as f:
2685                project_version = f.read()
2686            if project_version < __version__:
2687                ok = False
2688            elif project_version > __version__:
2689                warnings.warn(
2690                    f"The project expects a higher dlc2action version ({project_version}), please update!"
2691                )
2692        if not ok:
2693            project_config_path = os.path.join(self.project_path, "config")
2694            config_path = os.path.join(
2695                os.path.dirname(os.path.dirname(__path__)), "config"
2696            )
2697            episodes = self._episodes()
2698            folders = ["annotation", "augmentations", "data", "features", "model"]
2699
2700            project_annotation_configs = os.listdir(
2701                os.path.join(project_config_path, "annotation")
2702            )
2703            annotation_configs = os.listdir(os.path.join(config_path, "annotation"))
2704            for ann_config in annotation_configs:
2705                if ann_config not in project_annotation_configs:
2706                    shutil.copytree(
2707                        os.path.join(config_path, "annotation", ann_config),
2708                        os.path.join(project_config_path, "annotation", ann_config),
2709                        dirs_exist_ok=True,
2710                    )
2711                else:
2712                    project_pars = self._open_yaml(
2713                        os.path.join(project_config_path, "annotation", ann_config)
2714                    )
2715                    pars = self._open_yaml(
2716                        os.path.join(config_path, "annotation", ann_config)
2717                    )
2718                    new_keys = set(pars.keys()) - set(project_pars.keys())
2719                    for key in new_keys:
2720                        project_pars[key] = pars[key]
2721                        c = self._get_comment(pars.ca.items.get(key))
2722                        project_pars.yaml_add_eol_comment(c, key=key)
2723                        episodes.update(
2724                            condition=f"general/annotation_type::={ann_config}",
2725                            update={f"data/{key}": pars[key]},
2726                        )
2727
2728    def _initialize_project(
2729        self,
2730        data_type: str,
2731        annotation_type: str = None,
2732        data_path: str = None,
2733        annotation_path: str = None,
2734        copy: bool = True,
2735    ) -> None:
2736        """
2737        Initialize a new project
2738        """
2739
2740        if data_type not in self.data_types():
2741            raise ValueError(
2742                f"The {data_type} data type is not available. "
2743                f"Please choose from {self.data_types()}"
2744            )
2745        if annotation_type not in self.annotation_types():
2746            raise ValueError(
2747                f"The {annotation_type} annotation type is not available. "
2748                f"Please choose from {self.annotation_types()}"
2749            )
2750        os.mkdir(self.project_path)
2751        folders = ["results", "saved_datasets", "meta", "config"]
2752        for f in folders:
2753            os.mkdir(os.path.join(self.project_path, f))
2754        results_subfolders = [
2755            "model",
2756            "logs",
2757            "predictions",
2758            "splits",
2759            "searches",
2760        ]
2761        for sf in results_subfolders:
2762            os.mkdir(os.path.join(self.project_path, "results", sf))
2763        if data_path is not None:
2764            if copy:
2765                os.mkdir(os.path.join(self.project_path, "data"))
2766                shutil.copytree(
2767                    data_path,
2768                    os.path.join(self.project_path, "data"),
2769                    dirs_exist_ok=True,
2770                )
2771                data_path = os.path.join(self.project_path, "data")
2772        if annotation_path is not None:
2773            if copy:
2774                os.mkdir(os.path.join(self.project_path, "annotation"))
2775                shutil.copytree(
2776                    annotation_path,
2777                    os.path.join(self.project_path, "annotation"),
2778                    dirs_exist_ok=True,
2779                )
2780                annotation_path = os.path.join(self.project_path, "annotation")
2781        self._generate_config(
2782            data_type,
2783            annotation_type,
2784            data_path=data_path,
2785            annotation_path=annotation_path,
2786        )
2787        self._generate_meta()
2788
2789    def _read_types(self) -> Tuple[str, str]:
2790        """
2791        Get data type and annotation type from existing project files
2792        """
2793
2794        config_path = os.path.join(self.project_path, "config", "general.yaml")
2795        with open(config_path) as f:
2796            pars = YAML().load(f)
2797        data_type = pars["data_type"]
2798        annotation_type = pars["annotation_type"]
2799        return annotation_type, data_type
2800
2801    def _read_paths(self) -> Tuple[str, str]:
2802        """
2803        Get data type and annotation type from existing project files
2804        """
2805
2806        config_path = os.path.join(self.project_path, "config", "data.yaml")
2807        with open(config_path) as f:
2808            pars = YAML().load(f)
2809        data_path = pars["data_path"]
2810        annotation_path = pars["annotation_path"]
2811        return annotation_path, data_path
2812
2813    def _generate_config(
2814        self, data_type: str, annotation_type: str, data_path: str, annotation_path: str
2815    ) -> None:
2816        """
2817        Initialize the config files
2818        """
2819
2820        default_path = os.path.join(
2821            os.path.dirname(os.path.dirname(__file__)), "config"
2822        )
2823        config_path = os.path.join(self.project_path, "config")
2824        files = ["losses", "metrics", "ssl", "training"]
2825        for f in files:
2826            shutil.copy(os.path.join(default_path, f"{f}.yaml"), config_path)
2827        shutil.copytree(
2828            os.path.join(default_path, "model"), os.path.join(config_path, "model")
2829        )
2830        shutil.copytree(
2831            os.path.join(default_path, "features"),
2832            os.path.join(config_path, "features"),
2833        )
2834        shutil.copytree(
2835            os.path.join(default_path, "augmentations"),
2836            os.path.join(config_path, "augmentations"),
2837        )
2838        yaml = YAML()
2839        data_param_path = os.path.join(default_path, "data", f"{data_type}.yaml")
2840        if os.path.exists(data_param_path):
2841            with open(data_param_path, encoding="utf-8") as f:
2842                data_params = yaml.load(f)
2843        if data_params is None:
2844            data_params = {}
2845        if annotation_type is None:
2846            ann_params = {}
2847        else:
2848            ann_param_path = os.path.join(
2849                default_path, "annotation", f"{annotation_type}.yaml"
2850            )
2851            if os.path.exists(ann_param_path):
2852                ann_params = self._open_yaml(ann_param_path)
2853            elif annotation_type == "none":
2854                ann_params = {}
2855            else:
2856                raise ValueError(
2857                    f"The {annotation_type} data type is not available. "
2858                    f"Please choose from {BehaviorDataset.annotation_types()}"
2859                )
2860        if ann_params is None:
2861            ann_params = {}
2862        data_params = self._update(data_params, ann_params)
2863        data_params["data_path"] = data_path
2864        data_params["annotation_path"] = annotation_path
2865        with open(os.path.join(config_path, "data.yaml"), "w", encoding="utf-8") as f:
2866            yaml.dump(data_params, f)
2867        with open(os.path.join(default_path, "general.yaml"), encoding="utf-8") as f:
2868            general_params = yaml.load(f)
2869        general_params["data_type"] = data_type
2870        general_params["annotation_type"] = annotation_type
2871        with open(os.path.join(config_path, "general.yaml"), "w", encoding="utf-8") as f:
2872            yaml.dump(general_params, f)
2873
2874    def _generate_meta(self) -> None:
2875        """
2876        Initialize the meta files
2877        """
2878
2879        config_file = os.path.join(self.project_path, "config")
2880        meta_fields = ["time"]
2881        columns = [("meta", field) for field in meta_fields]
2882        episodes = pd.DataFrame(columns=pd.MultiIndex.from_tuples(columns))
2883        episodes.to_pickle(self._episodes_path())
2884        meta_fields = ["time", "objective"]
2885        result_fields = ["best_params", "best_value"]
2886        columns = [("meta", field) for field in meta_fields] + [
2887            ("results", field) for field in result_fields
2888        ]
2889        searches = pd.DataFrame(columns=pd.MultiIndex.from_tuples(columns))
2890        searches.to_pickle(self._searches_path())
2891        meta_fields = ["time"]
2892        columns = [("meta", field) for field in meta_fields]
2893        predictions = pd.DataFrame(columns=pd.MultiIndex.from_tuples(columns))
2894        predictions.to_pickle(self._predictions_path())
2895        with open(os.path.join(config_file, "data.yaml")) as f:
2896            data_keys = list(YAML().load(f).keys())
2897        saved_data = pd.DataFrame(columns=data_keys)
2898        saved_data.to_pickle(self._saved_datasets_path())
2899        pd.DataFrame().to_pickle(self._thresholds_path())
2900        # with open(self._version_path()) as f:
2901        #     f.write(__version__)
2902
2903    def _open_yaml(self, path: str) -> CommentedMap:
2904        """
2905        Load a parameter dictionary from a .yaml file
2906        """
2907
2908        with open(path, encoding="utf-8") as f:
2909            data = YAML().load(f)
2910        if data is None:
2911            data = {}
2912        return data
2913
2914    def _compare(self, d: Dict, u: Dict, allow_diff: float = 1e-7):
2915        """
2916        Compare nested dictionaries with 'almost equal' condition
2917        """
2918
2919        ok = True
2920        if u.keys() != d.keys():
2921            ok = False
2922        else:
2923            for k, v in u.items():
2924                if isinstance(v, Mapping):
2925                    ok = self._compare(d[k], v, allow_diff=allow_diff)
2926                else:
2927                    if isinstance(v, float) or isinstance(d[k], float):
2928                        if not isinstance(d[k], float) and not isinstance(d[k], int):
2929                            ok = False
2930                        elif not isinstance(v, float) and not isinstance(v, int):
2931                            ok = False
2932                        elif np.abs(v - d[k]) > allow_diff:
2933                            ok = False
2934                    elif v != d[k]:
2935                        ok = False
2936        return ok
2937
2938    def _check_comment(self, comment_sequence: List) -> bool:
2939        """
2940        Check if a comment already exists in a ruamel.yaml comment sequence
2941        """
2942
2943        if comment_sequence is None:
2944            return False
2945        c = self._get_comment(comment_sequence)
2946        if c != "":
2947            return True
2948        else:
2949            return False
2950
2951    def _get_comment(self, comment_sequence: List, strip=True) -> str:
2952        """
2953        Get the comment string from a ruamel.yaml comment sequence
2954        """
2955
2956        if comment_sequence is None:
2957            return ""
2958        c = ""
2959        for cm in comment_sequence:
2960            if cm is not None:
2961                if isinstance(cm, Iterable):
2962                    for c in cm:
2963                        if c is not None:
2964                            c = c.value
2965                            break
2966                    break
2967                else:
2968                    c = cm.value
2969                    break
2970        if strip:
2971            c = c.strip()
2972        return c
2973
2974    def _update(self, d: Union[CommentedMap, Dict], u: Union[CommentedMap, Dict]):
2975        """
2976        Update a nested dictionary
2977        """
2978
2979        if "general" in u and "model_name" in u["general"] and "model" in d:
2980            model_name = u["general"]["model_name"]
2981            if d["general"]["model_name"] != model_name:
2982                d["model"] = self._open_yaml(
2983                    os.path.join(
2984                        self.project_path, "config", "model", f"{model_name}.yaml"
2985                    )
2986                )
2987        d_copied = deepcopy(d)
2988        for k, v in u.items():
2989            if (
2990                k in d_copied
2991                and isinstance(d_copied[k], list)
2992                and isinstance(v, Mapping)
2993                and all([isinstance(x, int) for x in v.keys()])
2994            ):
2995                for kk, vv in v.items():
2996                    d_copied[k][kk] = vv
2997            elif (
2998                isinstance(v, Mapping)
2999                and k in d_copied
3000                and isinstance(d_copied[k], Mapping)
3001            ):
3002                if d_copied[k] is None:
3003                    d_k = CommentedMap()
3004                else:
3005                    d_k = d_copied[k]
3006                d_copied[k] = self._update(d_k, v)
3007            else:
3008                d_copied[k] = v
3009                if isinstance(u, CommentedMap) and u.ca.items.get(k) is not None:
3010                    c = self._get_comment(u.ca.items.get(k), strip=False)
3011                    if isinstance(d_copied, CommentedMap) and not self._check_comment(
3012                        d_copied.ca.items.get(k)
3013                    ):
3014                        d_copied.yaml_add_eol_comment(c, key=k)
3015        return d_copied
3016
3017    def _update_with_search(
3018        self,
3019        d: Dict,
3020        search_name: str,
3021        load_parameters: list = None,
3022        round_to_binary: list = None,
3023    ):
3024        """
3025        Update a dictionary with best parameters from a hyperparameter search
3026        """
3027
3028        u, _ = self._searches().get_best_params(
3029            search_name, load_parameters, round_to_binary
3030        )
3031        return self._update(d, u)
3032
3033    def _read_parameters(self, catch_blanks=True) -> Dict:
3034        """
3035        Compose a parameter dictionary to create a task from the config files
3036        """
3037
3038        config_path = os.path.join(self.project_path, "config")
3039        keys = [
3040            "data",
3041            "general",
3042            "losses",
3043            "metrics",
3044            "ssl",
3045            "training",
3046        ]
3047        parameters = {}
3048        for key in keys:
3049            parameters[key] = self._open_yaml(os.path.join(config_path, f"{key}.yaml"))
3050        features = parameters["general"]["feature_extraction"]
3051        parameters["features"] = self._open_yaml(
3052            os.path.join(config_path, "features", f"{features}.yaml")
3053        )
3054        transformer = options.extractor_to_transformer[features]
3055        parameters["augmentations"] = self._open_yaml(
3056            os.path.join(config_path, "augmentations", f"{transformer}.yaml")
3057        )
3058        model = parameters["general"]["model_name"]
3059        parameters["model"] = self._open_yaml(
3060            os.path.join(config_path, "model", f"{model}.yaml")
3061        )
3062        # input = parameters["general"]["input"]
3063        # parameters["model"] = self._open_yaml(
3064        #     os.path.join(config_path, "model", f"{model}.yaml")
3065        # )
3066        if catch_blanks:
3067            blanks = self._get_blanks()
3068            if len(blanks) > 0:
3069                self.list_blanks()
3070                raise ValueError(
3071                    f"Please fill in all the blanks before running experiments"
3072                )
3073        return parameters
3074
3075    def set_main_parameters(self, model_name: str = None, metric_names: List = None):
3076        """
3077        Select the model and the metrics
3078
3079        Parameters
3080        ----------
3081        model_name : str, optional
3082            model name; run `project.help("model") to find out more
3083        metric_names : list, optional
3084            a list of metric function names; run `project.help("metrics") to find out more
3085        """
3086
3087        pars = {"general": {}}
3088        if model_name is not None:
3089            assert model_name in options.models
3090            pars["general"]["model_name"] = model_name
3091        if metric_names is not None:
3092            for metric in metric_names:
3093                assert metric in options.metrics
3094            pars["general"]["metric_functions"] = metric_names
3095        self.update_parameters(pars)
3096
3097    def help(self, keyword: str = None):
3098        """
3099        Get information on available options
3100
3101        Parameters
3102        ----------
3103        keyword : str, optional
3104            the keyword for options (run without arguments to see which keywords are available)
3105
3106        """
3107
3108        if keyword is None:
3109            print("AVAILABLE HELP FUNCTIONS:")
3110            print("- Try running `project.help(keyword)` with the following keywords:")
3111            print("    - model: to get more information on available models,")
3112            print(
3113                "    - features: to get more information on available feature extraction modes,"
3114            )
3115            print(
3116                "    - partition_method: to get more information on available train/test/val partitioning methods,"
3117            )
3118            print("    - metrics: to see a list of available metric functions.")
3119            print("    - data: to see help for expected data structure")
3120            print(
3121                "- To start working with this project, first run `project.list_blanks()` to check which parameters need to be filled in."
3122            )
3123            print(
3124                "- After a model and metrics are set, run `project.list_basic_parameters()` to see a list of the most important parameters that you might want to modify"
3125            )
3126            print(
3127                f"- If you want to dig deeper, get the full dictionary with project._read_parameters() (it is a `ruamel.yaml.comments.CommentedMap` instance)."
3128            )
3129        elif keyword == "model":
3130            print("MODELS:")
3131            for key, model in options.models.items():
3132                print(f"{key}:")
3133                print(model.__doc__)
3134        elif keyword == "features":
3135            print("FEATURE EXTRACTORS:")
3136            for key, extractor in options.feature_extractors.items():
3137                print(f"{key}:")
3138                print(extractor.__doc__)
3139        elif keyword == "partition_method":
3140            print("PARTITION METHODS:")
3141            print(
3142                BehaviorDataset.partition_train_test_val.__doc__.split(
3143                    "The partitioning method:"
3144                )[1].split("val_frac :")[0]
3145            )
3146        elif keyword == "metrics":
3147            print("METRICS:")
3148            for key, metric in options.metrics.items():
3149                print(f"{key}:")
3150                print(metric.__doc__)
3151        elif keyword == "data":
3152            print("DATA:")
3153            print(f"Video data: {self.data_type}")
3154            print(options.input_stores[self.data_type].__doc__)
3155            print(f"Annotation data: {self.annotation_type}")
3156            print(options.annotation_stores[self.annotation_type].__doc__)
3157            print(
3158                "Annotation path and data path don't have to be separate, you can keep everything in one folder."
3159            )
3160        else:
3161            raise ValueError(f"The {keyword} keyword is not recognized")
3162        print("\n")
3163
3164    def _process_value(self, value):
3165        if isinstance(value, str):
3166            value = f'"{value}"'
3167        elif isinstance(value, CommentedSet):
3168            value = {x for x in value}
3169        return value
3170
3171    def _get_blanks(self):
3172        caught = []
3173        parameters = self._read_parameters(catch_blanks=False)
3174        for big_key, big_value in parameters.items():
3175            for key, value in big_value.items():
3176                if value == "???":
3177                    caught.append(
3178                        (big_key, key, self._get_comment(big_value.ca.items.get(key)))
3179                    )
3180        return caught
3181
3182    def list_blanks(self, blanks=None):
3183        """
3184        List parameters that need to be filled in
3185
3186        Parameters
3187        ----------
3188        blanks : list, optional
3189            a list of the parameters to list, if already known
3190        """
3191
3192        if blanks is None:
3193            blanks = self._get_blanks()
3194        if len(blanks) > 0:
3195            to_update = defaultdict(lambda: [])
3196            for b, k, c in blanks:
3197                to_update[b].append((k, c))
3198            print("Before running experiments, please update all the blanks.")
3199            print("To do that, you can run this.")
3200            print("--------------------------------------------------------")
3201            print(f"project.update_parameters(")
3202            print(f"    {{")
3203            for big_key, keys in to_update.items():
3204                print(f'        "{big_key}": {{')
3205                for key, comment in keys:
3206                    print(f'            "{key}": ..., {comment}')
3207                print(f"        }},")
3208            print(f"    }}")
3209            print(")")
3210            print("--------------------------------------------------------")
3211            print("Replace ... with relevant values.")
3212        else:
3213            print("There is no blanks left!")
3214
3215    def list_basic_parameters(
3216        self,
3217    ):
3218        """
3219        Get a list of most relevant parameters and code to modify them
3220        """
3221
3222        parameters = self._read_parameters()
3223        print("BASIC PARAMETERS:")
3224        model_name = parameters["general"]["model_name"]
3225        metric_names = parameters["general"]["metric_functions"]
3226        loss_name = parameters["general"]["loss_function"]
3227        feature_extraction = parameters["general"]["feature_extraction"]
3228        print("Here is a list of current parameters.")
3229        print(
3230            "You can copy this code, change the parameters you want to set and run it to update the project config."
3231        )
3232        print("--------------------------------------------------------")
3233        print("project.update_parameters(")
3234        print("    {")
3235        for group in ["general", "data", "training"]:
3236            print(f'        "{group}": {{')
3237            for key in options.basic_parameters[group]:
3238                if key in parameters[group]:
3239                    print(
3240                        f'            "{key}": {self._process_value(parameters[group][key])}, {self._get_comment(parameters[group].ca.items.get(key))}'
3241                    )
3242            print("        },")
3243        print('        "losses": {')
3244        print(f'            "{loss_name}": {{')
3245        for key in options.basic_parameters["losses"][loss_name]:
3246            if key in parameters["losses"][loss_name]:
3247                print(
3248                    f'                "{key}": {self._process_value(parameters["losses"][loss_name][key])}, {self._get_comment(parameters["losses"][loss_name].ca.items.get(key))}'
3249                )
3250        print("            },")
3251        print("        },")
3252        print('        "metrics": {')
3253        for metric in metric_names:
3254            print(f'            "{metric}": {{')
3255            for key in parameters["metrics"][metric]:
3256                print(
3257                    f'                "{key}": {self._process_value(parameters["metrics"][metric][key])}, {self._get_comment(parameters["metrics"][metric].ca.items.get(key))}'
3258                )
3259            print("            },")
3260        print("        },")
3261        print('        "model": {')
3262        for key in options.basic_parameters["model"][model_name]:
3263            if key in parameters["model"]:
3264                print(
3265                    f'            "{key}": {self._process_value(parameters["model"][key])}, {self._get_comment(parameters["model"].ca.items.get(key))}'
3266                )
3267
3268        print("        },")
3269        print('        "features": {')
3270        for key in options.basic_parameters["features"][feature_extraction]:
3271            if key in parameters["features"]:
3272                print(
3273                    f'            "{key}": {self._process_value(parameters["features"][key])}, {self._get_comment(parameters["features"].ca.items.get(key))}'
3274                )
3275
3276        print("        },")
3277        print('        "augmentations": {')
3278        for key in options.basic_parameters["augmentations"][feature_extraction]:
3279            if key in parameters["augmentations"]:
3280                print(
3281                    f'            "{key}": {self._process_value(parameters["augmentations"][key])}, {self._get_comment(parameters["augmentations"].ca.items.get(key))}'
3282                )
3283        print("        },")
3284        print("    },")
3285        print(")")
3286        print("--------------------------------------------------------")
3287        print("\n")
3288
3289    def _create_record(
3290        self,
3291        episode_name: str,
3292        behaviors_dict: Dict,
3293        load_episode: str = None,
3294        parameters_update: Dict = None,
3295        task: TaskDispatcher = None,
3296        load_epoch: int = None,
3297        load_search: str = None,
3298        load_parameters: list = None,
3299        round_to_binary: list = None,
3300        load_strict: bool = True,
3301        n_seeds: int = 1,
3302    ) -> TaskDispatcher:
3303        """
3304        Create a meta data episode record
3305        """
3306
3307        if episode_name in self._episodes().data.index:
3308            return
3309        if type(n_seeds) is not int or n_seeds < 1:
3310            raise ValueError(
3311                f"The n_seeds parameter has to be an integer larger than 0; got {n_seeds}"
3312            )
3313        if parameters_update is None:
3314            parameters_update = {}
3315        parameters = self._read_parameters()
3316        parameters = self._update(parameters, parameters_update)
3317        if load_search is not None:
3318            parameters = self._update_with_search(
3319                parameters, load_search, load_parameters, round_to_binary
3320            )
3321        parameters = self._fill(
3322            parameters,
3323            episode_name,
3324            load_episode,
3325            load_epoch=load_epoch,
3326            only_load_model=True,
3327            load_strict=load_strict,
3328            continuing=True,
3329        )
3330        self._save_episode(episode_name, parameters, behaviors_dict)
3331        return task
3332
3333    def _save_thresholds(
3334        self,
3335        episode_names: List,
3336        metric_name: str,
3337        parameters: Dict,
3338        thresholds: List,
3339        load_epochs: List,
3340    ):
3341        """
3342        Save optimal decision thresholds in the meta records
3343        """
3344
3345        metric_parameters = parameters["metrics"][metric_name]
3346        self._thresholds().save_thresholds(
3347            episode_names, load_epochs, metric_name, metric_parameters, thresholds
3348        )
3349
3350    def _save_episode(
3351        self,
3352        episode_name: str,
3353        parameters: Dict,
3354        behaviors_dict: Dict,
3355        suppress_validation: bool = False,
3356        training_time: str = None,
3357        norm_stats: Dict = None,
3358    ) -> None:
3359        """
3360        Save an episode in the meta files
3361        """
3362
3363        try:
3364            split_info = self._split_info_from_filename(
3365                parameters["training"]["split_path"]
3366            )
3367            parameters["training"]["partition_method"] = split_info["partition_method"]
3368        except:
3369            pass
3370        if norm_stats is not None:
3371            norm_stats = dict(norm_stats)
3372        parameters["training"]["stats"] = norm_stats
3373        self._episodes().save_episode(
3374            episode_name,
3375            parameters,
3376            behaviors_dict,
3377            suppress_validation=suppress_validation,
3378            training_time=training_time,
3379        )
3380
3381    def _update_episode_results(
3382        self,
3383        episode_name: str,
3384        logs: Tuple,
3385        training_time: str = None,
3386    ) -> None:
3387        """
3388        Save the results of a run in the meta files
3389        """
3390
3391        self._episodes().update_episode_results(episode_name, logs, training_time)
3392
3393    def _save_prediction(
3394        self,
3395        episode_name: str,
3396        parameters: Dict,
3397        behaviors_dict: Dict,
3398        embedding: bool = False,
3399        inference_time: str = None,
3400    ) -> None:
3401        """
3402        Save a prediction in the meta files
3403        """
3404
3405        parameters = self._update(
3406            parameters,
3407            {"meta": {"embedding": embedding, "inference_time": inference_time}},
3408        )
3409        self._predictions().save_episode(episode_name, parameters, behaviors_dict)
3410
3411    def _save_search(
3412        self,
3413        search_name: str,
3414        parameters: Dict,
3415        n_trials: int,
3416        best_params: Dict,
3417        best_value: float,
3418        metric: str,
3419        search_space: Dict,
3420    ) -> None:
3421        """
3422        Save a hyperparameter search in the meta files
3423        """
3424
3425        self._searches().save_search(
3426            search_name,
3427            parameters,
3428            n_trials,
3429            best_params,
3430            best_value,
3431            metric,
3432            search_space,
3433        )
3434
3435    def _save_stores(self, parameters: Dict) -> None:
3436        """
3437        Save a pickled dataset in the meta files
3438        """
3439
3440        name = os.path.basename(parameters["data"]["feature_save_path"])
3441        self._saved_datasets().save_store(name, self._get_data_pars(parameters))
3442        self.create_metadata_backup()
3443
3444    def _remove_stores(self, parameters: Dict, remove_active: bool = False) -> None:
3445        """
3446        Remove the pre-computed features folder
3447        """
3448
3449        name = os.path.basename(parameters["data"]["feature_save_path"])
3450        if remove_active or name not in self._episodes().get_active_datasets():
3451            self.remove_saved_features([name])
3452
3453    def _check_episode_validity(
3454        self, episode_name: str, allow_doublecolon: bool = False, force: bool = False
3455    ) -> None:
3456        """
3457        Check whether the episode name is valid
3458        """
3459
3460        if episode_name.startswith("_"):
3461            raise ValueError(
3462                "Names starting with an underscore are reserved by dlc2action and cannot be used!"
3463            )
3464        elif "." in episode_name:
3465            raise ValueError("Names containing '.' cannot be used!")
3466        if not allow_doublecolon and "::" in episode_name:
3467            raise ValueError(
3468                "Names containing '::' are reserved by dlc2action and cannot be used!"
3469            )
3470        if force:
3471            self.remove_episode(episode_name)
3472        elif not self._episodes().check_name_validity(episode_name):
3473            raise ValueError(
3474                f"The {episode_name} name is already taken! Set force=True to overwrite."
3475            )
3476
3477    def _check_search_validity(self, search_name: str, force: bool = False) -> None:
3478        """
3479        Check whether the search name is valid
3480        """
3481
3482        if search_name.startswith("_"):
3483            raise ValueError(
3484                "Names starting with an underscore are reserved by dlc2action and cannot be used!"
3485            )
3486        elif "." in search_name:
3487            raise ValueError("Names containing '.' cannot be used!")
3488        if force:
3489            self.remove_search(search_name)
3490        elif not self._searches().check_name_validity(search_name):
3491            raise ValueError(f"The {search_name} name is already taken!")
3492
3493    def _check_prediction_validity(
3494        self, prediction_name: str, force: bool = False
3495    ) -> None:
3496        """
3497        Check whether the prediction name is valid
3498        """
3499
3500        if prediction_name.startswith("_"):
3501            raise ValueError(
3502                "Names starting with an underscore are reserved by dlc2action and cannot be used!"
3503            )
3504        elif "." in prediction_name:
3505            raise ValueError("Names containing '.' cannot be used!")
3506        if force:
3507            self.remove_prediction(prediction_name)
3508        elif not self._predictions().check_name_validity(prediction_name):
3509            raise ValueError(f"The {prediction_name} name is already taken!")
3510
3511    def _training_time(self, episode_name: str) -> int:
3512        """
3513        Get the training time of an episode in seconds
3514        """
3515
3516        return self._episode(episode_name).training_time()
3517
3518    def _mask_path(self) -> str:
3519        """
3520        Get the path to the masks folder
3521        """
3522
3523        return os.path.join(self.project_path, "results", "masks")
3524
3525    def _thresholds_path(self) -> str:
3526        """
3527        Get the path to the thresholds meta file
3528        """
3529
3530        return os.path.join(self.project_path, "meta", "thresholds.pickle")
3531
3532    def _episodes_path(self) -> str:
3533        """
3534        Get the path to the episodes meta file
3535        """
3536
3537        return os.path.join(self.project_path, "meta", "episodes.pickle")
3538
3539    def _saved_datasets_path(self) -> str:
3540        """
3541        Get the path to the datasets meta file
3542        """
3543
3544        return os.path.join(self.project_path, "meta", "saved_datasets.pickle")
3545
3546    def _predictions_path(self) -> str:
3547        """
3548        Get the path to the predictions meta file
3549        """
3550
3551        return os.path.join(self.project_path, "meta", "predictions.pickle")
3552
3553    def _dataset_store_path(self, name: str) -> str:
3554        """
3555        Get the path to a specific pickled dataset
3556        """
3557
3558        return os.path.join(self.project_path, "saved_datasets", f"{name}.pickle")
3559
3560    def _searches_path(self) -> str:
3561        """
3562        Get the path to the hyperparameter search meta file
3563        """
3564
3565        return os.path.join(self.project_path, "meta", "searches.pickle")
3566
3567    def _search_path(self, name: str) -> str:
3568        """
3569        Get the default path to the graph folder for a specific hyperparameter search
3570        """
3571
3572        return os.path.join(self.project_path, "results", "searches", name)
3573
3574    def _version_path(self) -> str:
3575        """
3576        Get the path to the version file
3577        """
3578
3579        return os.path.join(self.project_path, "meta", "version.txt")
3580
3581    def _default_split_file(self, split_info: Dict) -> Optional[str]:
3582        """
3583        Generate a path to a split file from split parameters
3584        """
3585
3586        if split_info["partition_method"].startswith("time"):
3587            return None
3588        val_frac = split_info["val_frac"]
3589        test_frac = split_info["test_frac"]
3590        split_name = f'{split_info["partition_method"]}_val{val_frac * 100}%_test{test_frac * 100}%_len{split_info["len_segment"]}_overlap{split_info["overlap"]}'
3591        if not split_info["only_load_annotated"]:
3592            split_name += "_all"
3593        split_name += ".txt"
3594        return os.path.join(self.project_path, "results", "splits", split_name)
3595
3596    def _split_info_from_filename(self, split_name: str) -> Dict:
3597        """
3598        Get split parameters from default path to a split file
3599        """
3600
3601        if split_name is None:
3602            return {}
3603        try:
3604            name = os.path.basename(split_name)[:-4]
3605            split = name.split("_")
3606            if len(split) == 6:
3607                only_load_annotated = False
3608            else:
3609                only_load_annotated = True
3610            len_segment = int(split[3][3:])
3611            overlap = int(split[4][7:])
3612            method, val, test = split[:3]
3613            val = float(val[3:-1]) / 100
3614            test = float(test[4:-1]) / 100
3615            return {
3616                "partition_method": method,
3617                "val_frac": val,
3618                "test_frac": test,
3619                "only_load_annotated": only_load_annotated,
3620                "len_segment": len_segment,
3621                "overlap": overlap,
3622            }
3623        except:
3624            return {"partition_method": "file"}
3625
3626    def _fill(
3627        self,
3628        parameters: Dict,
3629        episode_name: str,
3630        load_experiment: str = None,
3631        load_epoch: int = None,
3632        load_strict: bool = True,
3633        only_load_model: bool = False,
3634        continuing: bool = False,
3635        enforce_split_parameters: bool = False,
3636    ) -> Dict:
3637        """
3638        Update the parameters from the config files with project specific information
3639
3640        Fill in the constant file path parameters and generate a unique log file and a model folder.
3641        Fill in the split file if the same split has been run before in the project and change partition method to
3642        from_file.
3643        Fill in saved data path if a dataset with the same data parameters already exists in the project.
3644        If load_experiment is not None, fill in the checkpoint path as well.
3645        The only_load_model training parameter is defined by the corresponding argument.
3646        If continuing is True, new files are not created and all information is loaded from load_experiment.
3647        If prediction is True, log and model files are not created.
3648        The enforce_split_parameters parameter is used to resolve conflicts
3649        between split file path and split parameters when they arise.
3650        """
3651
3652        pars = deepcopy(parameters)
3653        if episode_name == "_":
3654            self.remove_episode("_")
3655        log = os.path.join(self.project_path, "results", "logs", f"{episode_name}.txt")
3656        model_save_path = os.path.join(
3657            self.project_path, "results", "model", episode_name
3658        )
3659        if not continuing and (os.path.exists(log) or os.path.exists(model_save_path)):
3660            raise ValueError(
3661                f"The {episode_name} episode name is already in use! Set force=True to overwrite."
3662            )
3663        keys = ["val_frac", "test_frac", "partition_method"]
3664        if "len_segment" not in pars["general"] and "len_segment" in pars["data"]:
3665            pars["general"]["len_segment"] = pars["data"]["len_segment"]
3666        if "overlap" not in pars["general"] and "overlap" in pars["data"]:
3667            pars["general"]["overlap"] = pars["data"]["overlap"]
3668        if "len_segment" in pars["data"]:
3669            pars["data"].pop("len_segment")
3670        if "overlap" in pars["data"]:
3671            pars["data"].pop("overlap")
3672        split_info = {k: pars["training"][k] for k in keys}
3673        split_info["only_load_annotated"] = pars["general"]["only_load_annotated"]
3674        split_info["len_segment"] = pars["general"]["len_segment"]
3675        split_info["overlap"] = pars["general"]["overlap"]
3676        pars["training"]["log_file"] = log
3677        if not os.path.exists(model_save_path):
3678            os.mkdir(model_save_path)
3679        pars["training"]["model_save_path"] = model_save_path
3680        if load_experiment is not None:
3681            if load_experiment not in self._episodes().data.index:
3682                raise ValueError(f"The {load_experiment} episode does not exist!")
3683            old_episode = self._episode(load_experiment)
3684            old_file = old_episode.split_file()
3685            old_info = self._split_info_from_filename(old_file)
3686            if len(old_info) == 0:
3687                old_info = old_episode.split_info()
3688            if enforce_split_parameters:
3689                if split_info["partition_method"] != "file":
3690                    pars["training"]["split_path"] = self._default_split_file(
3691                        split_info
3692                    )
3693            else:
3694                equal = True
3695                if old_info["partition_method"] != split_info["partition_method"]:
3696                    equal = False
3697                if old_info["partition_method"] != "file":
3698                    if (
3699                        old_info["val_frac"] != split_info["val_frac"]
3700                        or old_info["test_frac"] != split_info["test_frac"]
3701                    ):
3702                        equal = False
3703                if not continuing and not equal:
3704                    warnings.warn(
3705                        f"The partitioning parameters in the loaded experiment ({old_info}) "
3706                        f"are not equal to the current partitioning parameters ({split_info}). "
3707                        f"The current parameters are replaced."
3708                    )
3709                pars["training"]["split_path"] = old_file
3710            pars["training"]["checkpoint_path"] = old_episode.model_file(load_epoch)
3711            pars["training"]["load_strict"] = load_strict
3712        else:
3713            pars["training"]["checkpoint_path"] = None
3714            if pars["training"]["partition_method"] == "file":
3715                if (
3716                    "split_path" not in pars["training"]
3717                    or pars["training"]["split_path"] is None
3718                ):
3719                    raise ValueError(
3720                        "The partition_method parameter is set to file but the "
3721                        "split_path parameter is not set!"
3722                    )
3723                elif not os.path.exists(pars["training"]["split_path"]):
3724                    raise ValueError(
3725                        f'The {pars["training"]["split_path"]} split file does not exist'
3726                    )
3727            else:
3728                pars["training"]["split_path"] = self._default_split_file(split_info)
3729        pars["training"]["only_load_model"] = only_load_model
3730        pars["data"]["saved_data_path"] = None
3731        pars["data"]["feature_save_path"] = None
3732        pars_data_copy = self._get_data_pars(pars)
3733        saved_data_name = self._saved_datasets().find_name(pars_data_copy)
3734        if saved_data_name is not None:
3735            pars["data"]["saved_data_path"] = self._dataset_store_path(saved_data_name)
3736            pars["data"]["feature_save_path"] = self._dataset_store_path(
3737                saved_data_name
3738            ).split(".")[0]
3739        else:
3740            dataset_path = self._dataset_store_path(episode_name)
3741            if os.path.exists(dataset_path):
3742                name, ext = dataset_path.split(".")
3743                i = 0
3744                while os.path.exists(f"{name}_{i}.{ext}"):
3745                    i += 1
3746                dataset_path = f"{name}_{i}.{ext}"
3747            pars["data"]["saved_data_path"] = dataset_path
3748            pars["data"]["feature_save_path"] = dataset_path.split(".")[0]
3749        split_split = pars["training"]["partition_method"].split(":")
3750        random = True
3751        for partition_method in options.partition_methods["fixed"]:
3752            method_split = partition_method.split(":")
3753            if len(split_split) != len(method_split):
3754                continue
3755            equal = True
3756            for x, y in zip(split_split, method_split):
3757                if y.startswith("{"):
3758                    continue
3759                if x != y:
3760                    equal = False
3761                    break
3762            if equal:
3763                random = False
3764                break
3765        if random and os.path.exists(pars["training"]["split_path"]):
3766            pars["training"]["partition_method"] = "file"
3767        pars["general"]["save_dataset"] = True
3768        return pars
3769
3770    def _get_data_pars(self, pars: Dict) -> Dict:
3771        """
3772        Get a complete description of the data from a general parameters dictionary
3773        """
3774
3775        pars_data_copy = deepcopy(pars["data"])
3776        for par in [
3777            "only_load_annotated",
3778            "exclusive",
3779            "feature_extraction",
3780            "ignored_clips",
3781            "len_segment",
3782            "overlap",
3783        ]:
3784            pars_data_copy[par] = pars["general"].get(par, None)
3785        pars_data_copy.update(pars["features"])
3786        return pars_data_copy
3787
3788    def count_classes(
3789        self,
3790        load_episode: str = None,
3791        parameters_update: Dict = None,
3792        remove_saved_features: bool = False,
3793        bouts: bool = True,
3794    ) -> Dict:
3795        """
3796        Get a dictionary of class counts in different modes
3797
3798        Parameters
3799        ----------
3800        load_episode : str, optional
3801            the episode settings to load
3802        parameters_update : dict, optional
3803            a dictionary of parameter updates (only for "data" and "general" categories)
3804        remove_saved_features : bool, default False
3805            if `True`, the dataset that is used for computation is then deleted
3806        bouts : bool, default False
3807            if `True`, instead of frame counts segment counts are returned
3808
3809        Returns
3810        -------
3811        class_counts : dict
3812            a dictionary where first-level keys are "train", "val" and "test", second-level keys are
3813            class names and values are class counts (in frames)
3814        """
3815
3816        if load_episode is None:
3817            task, parameters = self._make_task_training(
3818                episode_name="_", parameters_update=parameters_update, throwaway=True
3819            )
3820        else:
3821            task, parameters, _ = self._make_task_prediction(
3822                "_",
3823                load_episode=load_episode,
3824                parameters_update=parameters_update,
3825            )
3826        class_counts = task.count_classes(bouts=bouts)
3827        behaviors = task.behaviors_dict()
3828        class_counts = {
3829            kk: {behaviors.get(k, "unknown"): v for k, v in vv.items()}
3830            for kk, vv in class_counts.items()
3831        }
3832        if remove_saved_features:
3833            self._remove_stores(parameters)
3834        return class_counts
3835
3836    def plot_class_distribution(
3837        self,
3838        parameters_update: Dict = None,
3839        frame_cutoff: int = 1,
3840        bout_cutoff: int = 1,
3841        print_full: bool = False,
3842        remove_saved_features: bool = False,
3843    ) -> None:
3844        """
3845        Make a class distribution plot
3846
3847        You can either specify the parameters, choose an existing dataset or do neither (in that case a dataset
3848        is created or laoded for the computation with the default parameters).
3849
3850        Parameters
3851        ----------
3852        parameters_update : dict, optional
3853            a dictionary of parameter updates (only for "data" and "general" categories)
3854        remove_saved_features : bool, default False
3855            if `True`, the dataset that is used for computation is then deleted
3856        """
3857
3858        task, parameters = self._make_task_training(
3859            episode_name="_", parameters_update=parameters_update, throwaway=True
3860        )
3861        cutoff = {True: bout_cutoff, False: frame_cutoff}
3862        for bouts in [True, False]:
3863            class_counts = task.count_classes(bouts=bouts)
3864            if print_full:
3865                print("Bouts:" if bouts else "Frames:")
3866                for k, v in class_counts.items():
3867                    if sum(v.values()) != 0:
3868                        print(f"  {k}:")
3869                        values, keys = zip(
3870                            *[
3871                                x
3872                                for x in sorted(zip(v.values(), v.keys()), reverse=True)
3873                                if x[-1] != -100
3874                            ]
3875                        )
3876                        for kk, vv in zip(keys, values):
3877                            print(f"    {task.behaviors_dict()[kk]}: {vv}")
3878            class_counts = {
3879                kk: {k: v for k, v in vv.items() if v >= cutoff[bouts]}
3880                for kk, vv in class_counts.items()
3881            }
3882            for key, d in class_counts.items():
3883                if sum(d.values()) != 0:
3884                    values, keys = zip(
3885                        *[x for x in sorted(zip(d.values(), d.keys())) if x[-1] != -100]
3886                    )
3887                    keys = [task.behaviors_dict()[x] for x in keys]
3888                    plt.bar(keys, values)
3889                    plt.title(key)
3890                    plt.xticks(rotation=45, ha="right")
3891                    if bouts:
3892                        plt.ylabel("bouts")
3893                    else:
3894                        plt.ylabel("frames")
3895                    plt.tight_layout()
3896                    plt.show()
3897        if remove_saved_features:
3898            self._remove_stores(parameters)
3899
3900    def _generate_mask(
3901        self,
3902        mask_name: str,
3903        perc_annotated: float = 0.1,
3904        parameters_update: Dict = None,
3905        remove_saved_features: bool = False,
3906    ) -> None:
3907        """
3908        Generate a real_lens for active learning simulation
3909
3910        Parameters
3911        ----------
3912        mask_name : str
3913            the name of the real_lens
3914        """
3915
3916        print(f"GENERATING {mask_name}")
3917        task, parameters = self._make_task_training(
3918            f"_{mask_name}", parameters_update=parameters_update, throwaway=True
3919        )
3920        val_intervals, val_ids = task.dataset("val").get_intervals()  # 1
3921        unannotated_intervals = task.dataset("train").get_unannotated_intervals()  # 2
3922        unannotated_intervals = task.dataset("val").get_unannotated_intervals(
3923            first_intervals=unannotated_intervals
3924        )
3925        ids = task.dataset("train").get_ids()
3926        mask = {video_id: {} for video_id in ids}
3927        total_all = 0
3928        total_masked = 0
3929        for video_id, clip_ids in ids.items():
3930            for clip_id in clip_ids:
3931                frames = np.ones(task.dataset("train").get_len(video_id, clip_id))
3932                if clip_id in val_intervals[video_id]:
3933                    for start, end in val_intervals[video_id][clip_id]:
3934                        frames[start:end] = 0
3935                if clip_id in unannotated_intervals[video_id]:
3936                    for start, end in unannotated_intervals[video_id][clip_id]:
3937                        frames[start:end] = 0
3938                annotated = np.where(frames)[0]
3939                total_all += len(annotated)
3940                masked = annotated[-int(len(annotated) * (1 - perc_annotated)) :]
3941                total_masked += len(masked)
3942                mask[video_id][clip_id] = self._get_intervals(masked)
3943        file = {
3944            "masked": mask,
3945            "val_intervals": val_intervals,
3946            "val_ids": val_ids,
3947            "unannotated": unannotated_intervals,
3948        }
3949        self._save_mask(file, mask_name)
3950        if remove_saved_features:
3951            self._remove_stores(parameters)
3952        print("\n")
3953        # print(f'Unmasked: {sum([(vv == 0).sum() for v in real_lens.values() for vv in v.values()])} frames')
3954
3955    def _get_intervals(self, frame_indices: np.ndarray):
3956        """
3957        Get a list of intervals from a list of frame indices
3958
3959        Example: `[0, 1, 2, 5, 6, 8] -> [[0, 3], [5, 7], [8, 9]]`.
3960
3961        Parameters
3962        ----------
3963        frame_indices : np.ndarray
3964            a list of frame indices
3965
3966        Returns
3967        -------
3968        intervals : list
3969            a list of interval boundaries
3970        """
3971
3972        masked_intervals = []
3973        if len(frame_indices) > 0:
3974            breaks = np.where(np.diff(frame_indices) != 1)[0]
3975            start = frame_indices[0]
3976            for k in breaks:
3977                masked_intervals.append([start, frame_indices[k] + 1])
3978                start = frame_indices[k + 1]
3979            masked_intervals.append([start, frame_indices[-1] + 1])
3980        return masked_intervals
3981
3982    def _update_mask_with_uncertainty(
3983        self,
3984        mask_name: str,
3985        episode_name: Union[str, None],
3986        classes: List,
3987        load_epoch: int = None,
3988        n_frames: int = 10000,
3989        method: str = "least_confidence",
3990        min_length: int = 30,
3991        augment_n: int = 0,
3992        parameters_update: Dict = None,
3993    ):
3994        """
3995        Update real_lens with frame-wise uncertainty scores for active learning
3996
3997        Parameters
3998        ----------
3999        mask_name : str
4000            the name of the real_lens
4001        episode_name : str
4002            the name of the episode to load
4003        classes : list
4004            a list of class names or indices; their uncertainty scores will be computed separately and stacked
4005        n_frames : int, default 10000
4006            the number of frames to "annotate"
4007        method : {"least_confidence", "entropy"}
4008            the method used to calculate the scores from the probability predictions (`"least_confidence"`: `1 - p_i` if
4009            `p_i > 0.5` or `p_i` if `p_i < 0.5`; `"entropy"`: `- p_i * log(p_i) - (1 - p_i) * log(1 - p_i)`)
4010        min_length : int
4011            the minimum length (in frames) of the annotated intervals
4012        augment_n : int, default 0
4013            the number of augmentations to average over
4014        parameters_update : dict, optional
4015            the dictionary used to update the parameters from the config
4016
4017        Returns
4018        -------
4019        score_dicts : dict
4020            a nested dictionary where first level keys are video ids, second level keys are clip ids and values
4021            are score tensors
4022        """
4023
4024        print(f"UPDATING {mask_name}")
4025        task, parameters, _ = self._make_task_prediction(
4026            prediction_name=mask_name,
4027            load_episode=episode_name,
4028            parameters_update=parameters_update,
4029            load_epoch=load_epoch,
4030            mode="train",
4031        )
4032        score_tensors = task.generate_uncertainty_score(classes, augment_n, method)
4033        self._update_mask(task, mask_name, score_tensors, n_frames, min_length)
4034        print("\n")
4035
4036    def _update_mask_with_BALD(
4037        self,
4038        mask_name: str,
4039        episode_name: str,
4040        classes: List,
4041        load_epoch: int = None,
4042        augment_n: int = 0,
4043        n_frames: int = 10000,
4044        num_models: int = 10,
4045        kernel_size: int = 11,
4046        min_length: int = 30,
4047        parameters_update: Dict = None,
4048    ):
4049        """
4050        Update real_lens with frame-wise Bayesian Active Learning by Disagreement scores for active learning
4051
4052        Parameters
4053        ----------
4054        mask_name : str
4055            the name of the real_lens
4056        episode_name : str
4057            the name of the episode to load
4058        classes : list
4059            a list of class names or indices; their uncertainty scores will be computed separately and stacked
4060        augment_n : int, default 0
4061            the number of augmentations to average over
4062        n_frames : int, default 10000
4063            the number of frames to "annotate"
4064        num_models : int, default 10
4065            the number of dropout masks to apply
4066        kernel_size : int, default 11
4067            the size of the smoothing gaussian kernel
4068        min_length : int
4069            the minimum length (in frames) of the annotated intervals
4070        parameters_update : dict, optional
4071            the dictionary used to update the parameters from the config
4072
4073        Returns
4074        -------
4075        score_dicts : dict
4076            a nested dictionary where first level keys are video ids, second level keys are clip ids and values
4077            are score tensors
4078        """
4079
4080        print(f"UPDATING {mask_name}")
4081        task, parameters, mode = self._make_task_prediction(
4082            mask_name,
4083            load_episode=episode_name,
4084            parameters_update=parameters_update,
4085            load_epoch=load_epoch,
4086        )
4087        score_tensors = task.generate_bald_score(
4088            classes, augment_n, num_models, kernel_size
4089        )
4090        self._update_mask(task, mask_name, score_tensors, n_frames, min_length)
4091        print("\n")
4092
4093    def _suggest_intervals(
4094        self,
4095        dataset: BehaviorDataset,
4096        score_tensors: Dict,
4097        n_frames: int,
4098        min_length: int,
4099    ) -> Dict:
4100        """
4101        Suggest intervals with highest score of total length `n_frames`
4102
4103        Parameters
4104        ----------
4105        dataset : BehaviorDataset
4106            the dataset
4107        score_tensors : dict
4108            a dictionary where keys are clip ids and values are framewise score tensors
4109        n_frames : int
4110            the number of frames to "annotate"
4111        min_length : int
4112
4113        Returns
4114        -------
4115        active_learning_intervals : Dict
4116            active learning dictionary with suggested intervals
4117        """
4118
4119        video_intervals, _ = dataset.get_intervals()
4120        taken = {
4121            video_id: defaultdict(lambda: {}) for video_id in video_intervals.keys()
4122        }
4123        annotated = dataset.get_annotated_intervals()
4124        for video_id in video_intervals:
4125            for clip_id in video_intervals[video_id]:
4126                taken[video_id][clip_id] = torch.zeros(
4127                    dataset.get_len(video_id, clip_id)
4128                )
4129                if video_id in annotated and clip_id in annotated[video_id]:
4130                    for start, end in annotated[video_id][clip_id]:
4131                        score_tensors[video_id][clip_id][:, start:end] = -10
4132                        taken[video_id][clip_id][int(start) : int(end)] = 1
4133        n_frames = (
4134            sum([(vv == 1).sum() for v in taken.values() for vv in v.values()])
4135            + n_frames
4136        )
4137        factor = 1
4138        threshold_start = float(
4139            torch.mean(
4140                torch.tensor(
4141                    [
4142                        torch.mean(
4143                            torch.tensor([torch.mean(y[y > 0]) for y in x.values()])
4144                        )
4145                        for x in score_tensors.values()
4146                    ]
4147                )
4148            )
4149        )
4150        while (
4151            sum([(vv == 1).sum() for v in taken.values() for vv in v.values()])
4152            < n_frames
4153        ):
4154            threshold = threshold_start * factor
4155            intervals = []
4156            interval_scores = []
4157            key1 = list(score_tensors.keys())[0]
4158            key2 = list(score_tensors[key1].keys())[0]
4159            num_scores = score_tensors[key1][key2].shape[0]
4160            for i in range(num_scores):
4161                v_dict = dataset.find_valleys(
4162                    predicted=score_tensors,
4163                    threshold=threshold,
4164                    min_frames=min_length,
4165                    main_class=i,
4166                    low=False,
4167                )
4168                for v_id, interval_list in v_dict.items():
4169                    intervals += [x + [v_id] for x in interval_list]
4170                    interval_scores += [
4171                        float(torch.mean(score_tensors[v_id][clip_id][i, start:end]))
4172                        for start, end, clip_id in interval_list
4173                    ]
4174            intervals = np.array(intervals)[np.argsort(interval_scores)[::-1]]
4175            i = 0
4176            while sum(
4177                [(vv == 1).sum() for v in taken.values() for vv in v.values()]
4178            ) < n_frames and i < len(intervals):
4179                start, end, clip_id, video_id = intervals[i]
4180                i += 1
4181                taken[video_id][clip_id][int(start) : int(end)] = 1
4182            factor *= 0.9
4183            if factor < 0.05:
4184                warnings.warn(f"Could not find enough frames!")
4185                break
4186        active_learning_intervals = {video_id: [] for video_id in video_intervals}
4187        for video_id in taken:
4188            for clip_id in taken[video_id]:
4189                if video_id in annotated and clip_id in annotated[video_id]:
4190                    for start, end in annotated[video_id][clip_id]:
4191                        taken[video_id][clip_id][int(start) : int(end)] = 0
4192                if (taken[video_id][clip_id] == 1).sum() == 0:
4193                    continue
4194                indices = np.where(taken[video_id][clip_id].numpy())[0]
4195                boundaries = self._get_intervals(indices)
4196                active_learning_intervals[video_id] += [
4197                    [start, end, clip_id] for start, end in boundaries
4198                ]
4199        return active_learning_intervals
4200
4201    def _update_mask(
4202        self,
4203        task: TaskDispatcher,
4204        mask_name: str,
4205        score_tensors: Dict,
4206        n_frames: int,
4207        min_length: int,
4208    ) -> None:
4209        """
4210        Update the real_lens with intervals with the highest score of total length `n_frames`
4211
4212        Parameters
4213        ----------
4214        mask_name : str
4215            the name of the real_lens
4216        score_tensors : dict
4217            a dictionary where keys are clip ids and values are framewise score tensors
4218        n_frames : int
4219            the number of frames to "annotate"
4220        min_length : int
4221            the minimum length of the annotated intervals
4222        """
4223
4224        mask = self._load_mask(mask_name)
4225        video_intervals, _ = task.dataset("train").get_intervals()
4226        masked = {
4227            video_id: defaultdict(lambda: {}) for video_id in video_intervals.keys()
4228        }
4229        total_masked = 0
4230        total_all = 0
4231        for video_id in video_intervals:
4232            for clip_id in video_intervals[video_id]:
4233                masked[video_id][clip_id] = torch.zeros(
4234                    task.dataset("train").get_len(video_id, clip_id)
4235                )
4236                if (
4237                    video_id in mask["unannotated"]
4238                    and clip_id in mask["unannotated"][video_id]
4239                ):
4240                    for start, end in mask["unannotated"][video_id][clip_id]:
4241                        score_tensors[video_id][clip_id][:, start:end] = -10
4242                        masked[video_id][clip_id][int(start) : int(end)] = 1
4243                if (
4244                    video_id in mask["val_intervals"]
4245                    and clip_id in mask["val_intervals"][video_id]
4246                ):
4247                    for start, end in mask["val_intervals"][video_id][clip_id]:
4248                        score_tensors[video_id][clip_id][:, start:end] = -10
4249                        masked[video_id][clip_id][int(start) : int(end)] = 1
4250                total_all += torch.sum(masked[video_id][clip_id] == 0)
4251                if video_id in mask["masked"] and clip_id in mask["masked"][video_id]:
4252                    # print(f'{real_lens["masked"][video_id][clip_id]=}')
4253                    for start, end in mask["masked"][video_id][clip_id]:
4254                        masked[video_id][clip_id][int(start) : int(end)] = 1
4255                        total_masked += end - start
4256        old_n_frames = sum(
4257            [(vv == 0).sum() for v in masked.values() for vv in v.values()]
4258        )
4259        n_frames = old_n_frames + n_frames
4260        factor = 1
4261        while (
4262            sum([(vv == 0).sum() for v in masked.values() for vv in v.values()])
4263            < n_frames
4264        ):
4265            threshold = float(
4266                torch.mean(
4267                    torch.tensor(
4268                        [
4269                            torch.mean(
4270                                torch.tensor([torch.mean(y[y > 0]) for y in x.values()])
4271                            )
4272                            for x in score_tensors.values()
4273                        ]
4274                    )
4275                )
4276            )
4277            threshold = threshold * factor
4278            intervals = []
4279            interval_scores = []
4280            key1 = list(score_tensors.keys())[0]
4281            key2 = list(score_tensors[key1].keys())[0]
4282            num_scores = score_tensors[key1][key2].shape[0]
4283            for i in range(num_scores):
4284                v_dict = task.dataset("train").find_valleys(
4285                    predicted=score_tensors,
4286                    threshold=threshold,
4287                    min_frames=min_length,
4288                    main_class=i,
4289                    low=False,
4290                )
4291                for v_id, interval_list in v_dict.items():
4292                    intervals += [x + [v_id] for x in interval_list]
4293                    interval_scores += [
4294                        float(torch.mean(score_tensors[v_id][clip_id][i, start:end]))
4295                        for start, end, clip_id in interval_list
4296                    ]
4297            intervals = np.array(intervals)[np.argsort(interval_scores)[::-1]]
4298            i = 0
4299            while sum(
4300                [(vv == 0).sum() for v in masked.values() for vv in v.values()]
4301            ) < n_frames and i < len(intervals):
4302                start, end, clip_id, video_id = intervals[i]
4303                i += 1
4304                masked[video_id][clip_id][int(start) : int(end)] = 0
4305            factor *= 0.9
4306            if factor < 0.05:
4307                warnings.warn(f"Could not find enough frames!")
4308                break
4309        mask["masked"] = {video_id: {} for video_id in video_intervals}
4310        total_masked_new = 0
4311        for video_id in masked:
4312            for clip_id in masked[video_id]:
4313                if (
4314                    video_id in mask["unannotated"]
4315                    and clip_id in mask["unannotated"][video_id]
4316                ):
4317                    for start, end in mask["unannotated"][video_id][clip_id]:
4318                        masked[video_id][clip_id][int(start) : int(end)] = 0
4319                if (
4320                    video_id in mask["val_intervals"]
4321                    and clip_id in mask["val_intervals"][video_id]
4322                ):
4323                    for start, end in mask["val_intervals"][video_id][clip_id]:
4324                        masked[video_id][clip_id][int(start) : int(end)] = 0
4325                indices = np.where(masked[video_id][clip_id].numpy())[0]
4326                mask["masked"][video_id][clip_id] = self._get_intervals(indices)
4327        for video_id in mask["masked"]:
4328            for clip_id in mask["masked"][video_id]:
4329                for start, end in mask["masked"][video_id][clip_id]:
4330                    total_masked_new += end - start
4331        self._save_mask(mask, mask_name)
4332        with open(
4333            os.path.join(self.project_path, "results", f"{mask_name}.txt"), "a"
4334        ) as f:
4335            f.write(f"from {total_masked} to {total_masked_new} / {total_all}" + "\n")
4336        print(f"Unmasked from {total_masked} to {total_masked_new} / {total_all}")
4337
4338    def plot_confusion_matrix(
4339        self,
4340        episode_name: str,
4341        load_epoch: int = None,
4342        parameters_update: Dict = None,
4343        type: str = "recall",
4344        mode: str = "val",
4345        remove_saved_features: bool = False,
4346    ) -> Tuple[ndarray, Iterable]:
4347        """
4348        Make a confusion matrix plot and return the data
4349
4350        If the annotation is non-exclusive, only false positive labels are considered.
4351
4352        Parameters
4353        ----------
4354        episode_name : str
4355            the name of the episode to load
4356        load_epoch : int, optional
4357            the index of the epoch to load (by default the last one is loaded)
4358        parameters_update : dict, optional
4359            a dictionary of parameter updates (only for "data" and "general" categories)
4360        mode : {'val', 'all', 'test', 'train'}
4361            the subset of the data to make the prediction for (forced to 'all' if data_path is not None)
4362        type : {"recall", "precision"}
4363            for datasets with non-exclusive annotation, if `type` is `"recall"`, only false positives are taken
4364            into account, and if `type` is `"precision"`, only false negatives
4365        remove_saved_features : bool, default False
4366            if `True`, the dataset that is used for computation is then deleted
4367
4368        Returns
4369        -------
4370        confusion_matrix : np.ndarray
4371            a confusion matrix of shape `(#classes, #classes)` where `A[i, j] = F_ij/N_i`, `F_ij` is the number of
4372            frames that have the i-th label in the ground truth and a false positive j-th label in the prediction,
4373            `N_i` is the number of frames that have the i-th label in the ground truth
4374        classes : list
4375            a list of labels
4376        """
4377
4378        task, parameters, mode = self._make_task_prediction(
4379            "_",
4380            load_episode=episode_name,
4381            load_epoch=load_epoch,
4382            parameters_update=parameters_update,
4383            mode=mode,
4384        )
4385        dataset = task.dataset(mode)
4386        prediction = task.predict(dataset, raw_output=True)
4387        confusion_matrix, classes, type = dataset.get_confusion_matrix(prediction, type)
4388        if remove_saved_features:
4389            self._remove_stores(parameters)
4390        fig, ax = plt.subplots(figsize=(len(classes), len(classes)))
4391        ax.imshow(confusion_matrix)
4392        # Show all ticks and label them with the respective list entries
4393        ax.set_xticks(np.arange(len(classes)))
4394        ax.set_xticklabels(classes)
4395        ax.set_yticks(np.arange(len(classes)))
4396        ax.set_yticklabels(classes)
4397        # Rotate the tick labels and set their alignment.
4398        plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
4399        # Loop over data dimensions and create text annotations.
4400        for i in range(len(classes)):
4401            for j in range(len(classes)):
4402                ax.text(
4403                    j,
4404                    i,
4405                    np.round(confusion_matrix[i, j], 2),
4406                    ha="center",
4407                    va="center",
4408                    color="w",
4409                )
4410        if type is not None:
4411            ax.set_title(f"{type} {episode_name}")
4412        else:
4413            ax.set_title(episode_name)
4414        fig.tight_layout()
4415        plt.show()
4416        return confusion_matrix, classes
4417
4418    def plot_predictions(
4419        self,
4420        episode_name: str,
4421        load_epoch: int = None,
4422        parameters_update: Dict = None,
4423        add_legend: bool = True,
4424        ground_truth: bool = True,
4425        colormap: str = "viridis",
4426        hide_axes: bool = False,
4427        min_classes: int = 1,
4428        width: float = 10,
4429        whole_video: bool = False,
4430        transparent: bool = False,
4431        drop_classes: Set = None,
4432        search_classes: Set = None,
4433        num_plots: int = 1,
4434        remove_saved_features: bool = False,
4435        smooth_interval_prediction: int = 0,
4436        data_path: str = None,
4437        file_paths: Set = None,
4438        mode: str = "val",
4439        behavior_name: str = None,
4440    ) -> None:
4441        """
4442        Visualize random predictions
4443
4444        Parameters
4445        ----------
4446        episode_name : str
4447            the name of the episode to load
4448        load_epoch : int, optional
4449            the epoch to load (by default last)
4450        parameters_update : dict, optional
4451            parameter update dictionary
4452        add_legend : bool, default True
4453            if True, legend will be added to the plot
4454        ground_truth : bool, default True
4455            if True, ground truth will be added to the plot
4456        colormap : str, default 'Accent'
4457            the `matplotlib` colormap to use
4458        hide_axes : bool, default True
4459            if `True`, the axes will be hidden on the plot
4460        min_classes : int, default 1
4461            the minimum number of classes in a displayed interval
4462        width : float, default 10
4463            the width of the plot
4464        whole_video : bool, default False
4465            if `True`, whole videos are plotted instead of segments
4466        transparent : bool, default False
4467            if `True`, the background on the plot is transparent
4468        drop_classes : set, optional
4469            a set of class names to not be displayed
4470        search_classes : set, optional
4471            if given, only intervals where at least one of the classes is in ground truth will be shown
4472        num_plots : int, default 1
4473            the number of plots to make
4474        remove_saved_features : bool, default False
4475            if `True`, the dataset will be deleted after computation
4476        smooth_interval_prediction : int, default 0
4477            if >0, predictions shorter than this number of frames are removed (filled with prediction for the previous frame)
4478        data_path : str, optional
4479            the data path to run the prediction for
4480        mode : {'all', 'test', 'val', 'train'}
4481            the subset of the data to make the prediction for (forced to 'all' if data_path is not None)
4482        file_paths : set, optional
4483            a set of string file paths (data with all prefixes + feature files, in any order) to run the prediction
4484            for
4485        behavior_name : str, optional
4486            for non-exclusive classificaton datasets, choose which behavior to visualize (by default first in list)
4487        """
4488
4489        other_path = os.path.join(self.project_path, "results", "other")
4490        task, parameters, mode = self._make_task_prediction(
4491            "_",
4492            load_episode=episode_name,
4493            parameters_update=parameters_update,
4494            load_epoch=load_epoch,
4495            data_path=data_path,
4496            file_paths=file_paths,
4497            mode=mode,
4498        )
4499        if not os.path.exists(other_path):
4500            os.mkdir(other_path)
4501        for i in range(num_plots):
4502            task.visualize_results(
4503                save_path=os.path.join(
4504                    other_path, f"{episode_name}_prediction_{i}.jpg"
4505                ),
4506                add_legend=add_legend,
4507                ground_truth=ground_truth,
4508                colormap=colormap,
4509                hide_axes=hide_axes,
4510                min_classes=min_classes,
4511                whole_video=whole_video,
4512                transparent=transparent,
4513                dataset=mode,
4514                drop_classes=drop_classes,
4515                search_classes=search_classes,
4516                width=width,
4517                smooth_interval_prediction=smooth_interval_prediction,
4518                behavior_name=behavior_name,
4519            )
4520        if remove_saved_features:
4521            self._remove_stores(parameters)
4522
4523    def create_metadata_backup(self) -> None:
4524        """
4525        Create a copy of the meta files
4526        """
4527
4528        meta_copy_path = os.path.join(self.project_path, "meta", "backup")
4529        meta_path = os.path.join(self.project_path, "meta")
4530        if os.path.exists(meta_copy_path):
4531            shutil.rmtree(meta_copy_path)
4532        os.mkdir(meta_copy_path)
4533        for file in os.listdir(meta_path):
4534            if file == "backup":
4535                continue
4536            shutil.copy(
4537                os.path.join(meta_path, file), os.path.join(meta_copy_path, file)
4538            )
4539
4540    def load_metadata_backup(self) -> None:
4541        """
4542        Load from previously created meta data backup (in case of corruption)
4543        """
4544
4545        meta_copy_path = os.path.join(self.project_path, "meta", "backup")
4546        meta_path = os.path.join(self.project_path, "meta")
4547        for file in os.listdir(meta_copy_path):
4548            shutil.copy(
4549                os.path.join(meta_copy_path, file), os.path.join(meta_path, file)
4550            )
4551
4552    def get_behavior_dictionary(self, episode_name: str) -> Dict:
4553        """
4554        Get the behavior dictionary for an episode
4555
4556        Parameters
4557        ----------
4558        episode_name : str
4559            the name of the episode
4560
4561        Returns
4562        -------
4563        behaviors_dictionary : dict
4564            a dictionary where keys are label indices and values are label names
4565        """
4566
4567        run = self._episodes().get_runs(episode_name)[0]
4568        return self._episode(run).get_behaviors_dict()
4569
4570    def import_episodes(
4571        self,
4572        episodes_directory: str,
4573        name_map: Dict = None,
4574        repeat_policy: str = "error",
4575    ) -> None:
4576        """
4577        Import episodes exported with `Project.export_episodes`
4578
4579        Parameters
4580        ----------
4581        episodes_directory : str
4582            the path to the exported episodes directory
4583        name_map : dict
4584            a name change dictionary for the episodes: keys are old names, values are new names
4585        """
4586
4587        if name_map is None:
4588            name_map = {}
4589        episodes = pd.read_pickle(os.path.join(episodes_directory, "episodes.pickle"))
4590        to_remove = []
4591        import_string = "Imported episodes: "
4592        for episode_name in episodes.index:
4593            if episode_name in name_map:
4594                import_string += f"{episode_name} "
4595                episode_name = name_map[episode_name]
4596                import_string += f"({episode_name}), "
4597            else:
4598                import_string += f"{episode_name}, "
4599            try:
4600                self._check_episode_validity(episode_name, allow_doublecolon=True)
4601            except ValueError as e:
4602                if str(e).endswith("is already taken!"):
4603                    if repeat_policy == "skip":
4604                        to_remove.append(episode_name)
4605                    elif repeat_policy == "force":
4606                        self.remove_episode(episode_name)
4607                    elif repeat_policy == "error":
4608                        raise ValueError(
4609                            f"The {episode_name} episode name is already taken; please use the name_map parameter to rename it"
4610                        )
4611                    else:
4612                        raise ValueError(
4613                            f"The {repeat_policy} repeat policy is not recognized; please choose from ['skip', 'force' ans 'error']"
4614                        )
4615        episodes = episodes.drop(index=to_remove)
4616        self._episodes().update(
4617            episodes,
4618            name_map=name_map,
4619            force=(repeat_policy == "force"),
4620            data_path=self.data_path,
4621            annotation_path=self.annotation_path,
4622        )
4623        for episode_name in episodes.index:
4624            if episode_name in name_map:
4625                new_episode_name = name_map[episode_name]
4626            else:
4627                new_episode_name = episode_name
4628            model_dir = os.path.join(
4629                self.project_path, "results", "model", new_episode_name
4630            )
4631            old_model_dir = os.path.join(episodes_directory, "model", episode_name)
4632            if os.path.exists(model_dir):
4633                shutil.rmtree(model_dir)
4634            os.mkdir(model_dir)
4635            for file in os.listdir(old_model_dir):
4636                shutil.copyfile(
4637                    os.path.join(old_model_dir, file), os.path.join(model_dir, file)
4638                )
4639            log_file = os.path.join(
4640                self.project_path, "results", "logs", f"{new_episode_name}.txt"
4641            )
4642            old_log_file = os.path.join(
4643                episodes_directory, "logs", f"{episode_name}.txt"
4644            )
4645            shutil.copyfile(old_log_file, log_file)
4646        print(import_string)
4647        print("\n")
4648
4649    def export_episodes(
4650        self, episode_names: List, output_directory: str, name: str = None
4651    ) -> None:
4652        """
4653        Save selected episodes as a file that can be imported into another project with `Project.import_episodes`
4654
4655        Parameters
4656        ----------
4657        episode_names : list
4658            a list of string episode names
4659        output_directory : str
4660            the path to the directory where the episodes will be saved
4661        name : str, optional
4662            the name of the episodes directory (by default `exported_episodes`)
4663        """
4664
4665        if name is None:
4666            name = "exported_episodes"
4667        if os.path.exists(
4668            os.path.join(output_directory, name + ".zip")
4669        ) or os.path.exists(os.path.join(output_directory, name)):
4670            i = 1
4671            while os.path.exists(
4672                os.path.join(output_directory, name + f"_{i}.zip")
4673            ) or os.path.exists(os.path.join(output_directory, name + f"_{i}")):
4674                i += 1
4675            name = name + f"_{i}"
4676        dest_dir = os.path.join(output_directory, name)
4677        os.mkdir(dest_dir)
4678        os.mkdir(os.path.join(dest_dir, "model"))
4679        os.mkdir(os.path.join(dest_dir, "logs"))
4680        runs = []
4681        for episode in episode_names:
4682            runs += self._episodes().get_runs(episode)
4683        for run in runs:
4684            shutil.copytree(
4685                os.path.join(self.project_path, "results", "model", run),
4686                os.path.join(dest_dir, "model", run),
4687            )
4688            shutil.copyfile(
4689                os.path.join(self.project_path, "results", "logs", f"{run}.txt"),
4690                os.path.join(dest_dir, "logs", f"{run}.txt"),
4691            )
4692        data = self._episodes().get_subset(runs)
4693        data.to_pickle(os.path.join(dest_dir, "episodes.pickle"))
4694
4695    def get_results_table(
4696        self,
4697        episode_names: List,
4698        metrics: List = None,
4699        include_std: bool = False,
4700        classes: List = None,
4701    ):
4702        """
4703        Genererate a `pandas` dataframe with a summary of episode results
4704
4705        Parameters
4706        ----------
4707        episode_names : list
4708            a list of names of episodes to include
4709        metrics : list, optional
4710            a list of metric names to include
4711        include_std : bool, default False
4712            if `True`, for episodes with multiple runs the mean and standard deviation will be displayed;
4713            otherwise only mean
4714        classes : list, optional
4715            a list of names of classes to include (by default all are included)
4716
4717        Returns
4718        -------
4719        results : pd.DataFrame
4720            a table with the results
4721        """
4722
4723        run_names = []
4724        for episode in episode_names:
4725            run_names += self._episodes().get_runs(episode)
4726        episodes = self.list_episodes(run_names, print_results=False)
4727        metric_columns = [x for x in episodes.columns if x[0] == "results"]
4728        results_df = pd.DataFrame()
4729        if metrics is not None:
4730            metric_columns = [
4731                x for x in metric_columns if x[1].split("_")[0] in metrics
4732            ]
4733        for episode in episode_names:
4734            results = []
4735            metric_set = set()
4736            for run in self._episodes().get_runs(episode):
4737                beh_dict = self.get_behavior_dictionary(run)
4738                res_dict = defaultdict(lambda: {})
4739                for column in metric_columns:
4740                    if np.isnan(episodes.loc[run, column]):
4741                        continue
4742                    split = column[1].split("_")
4743                    if split[-1].isnumeric():
4744                        beh_ind = int(split[-1])
4745                        metric_name = "_".join(split[:-1])
4746                        beh = beh_dict[beh_ind]
4747                    else:
4748                        beh = "average"
4749                        metric_name = column[1]
4750                    res_dict[beh][metric_name] = episodes.loc[run, column]
4751                    metric_set.add(metric_name)
4752                if "average" not in res_dict:
4753                    res_dict["average"] = {}
4754                for metric in metric_set:
4755                    if metric not in res_dict["average"]:
4756                        arr = [
4757                            res_dict[beh][metric]
4758                            for beh in res_dict
4759                            if metric in res_dict[beh]
4760                        ]
4761                        res_dict["average"][metric] = np.mean(arr)
4762                results.append(res_dict)
4763            episode_results = {}
4764            for metric in metric_set:
4765                for beh in results[0].keys():
4766                    if classes is not None and beh not in classes:
4767                        continue
4768                    arr = []
4769                    for res_dict in results:
4770                        if metric in res_dict[beh]:
4771                            arr.append(res_dict[beh][metric])
4772                    if len(arr) > 0:
4773                        if include_std:
4774                            episode_results[
4775                                (beh, f"{episode} {metric} mean")
4776                            ] = np.mean(arr)
4777                            episode_results[(beh, f"{episode} {metric} std")] = np.std(
4778                                arr
4779                            )
4780                        else:
4781                            episode_results[(beh, f"{episode} {metric}")] = np.mean(arr)
4782            for key, value in episode_results.items():
4783                results_df.loc[key[0], key[1]] = value
4784        print(f"RESULTS:")
4785        print(results_df)
4786        print("\n")
4787        return results_df
4788
4789    def episode_exists(self, episode_name: str) -> bool:
4790        """
4791        Check if an episode already exists
4792
4793        Parameters
4794        ----------
4795        episode_name : str
4796            the episode name
4797
4798        Returns
4799        -------
4800        exists : bool
4801            `True` if the episode exists
4802        """
4803
4804        return self._episodes().check_name_validity(episode_name)
4805
4806    def search_exists(self, search_name: str) -> bool:
4807        """
4808        Check if a search already exists
4809
4810        Parameters
4811        ----------
4812        search_name : str
4813            the search name
4814
4815        Returns
4816        -------
4817        exists : bool
4818            `True` if the search exists
4819        """
4820
4821        return self._searches().check_name_validity(search_name)
4822
4823    def prediction_exists(self, prediction_name: str) -> bool:
4824        """
4825        Check if a prediction already exists
4826
4827        Parameters
4828        ----------
4829        prediction_name : str
4830            the prediction name
4831
4832        Returns
4833        -------
4834        exists : bool
4835            `True` if the prediction exists
4836        """
4837
4838        return self._predictions().check_name_validity(prediction_name)
4839
4840    @staticmethod
4841    def project_name_available(projects_path: str, project_name: str):
4842        if projects_path is None:
4843            projects_path = os.path.join(str(Path.home()), "DLC2Action")
4844        return not os.path.exists(os.path.join(projects_path, project_name))
4845
4846    def _update_episode_metrics(self, episode_name: str, metrics: Dict):
4847        """
4848        Update meta data with evaluation results
4849        """
4850
4851        self._episodes().update_episode_metrics(episode_name, metrics)
4852
4853    def rename_episode(self, episode_name: str, new_episode_name: str):
4854        shutil.move(
4855            os.path.join(self.project_path, "results", "model", episode_name),
4856            os.path.join(self.project_path, "results", "model", new_episode_name),
4857        )
4858        shutil.move(
4859            os.path.join(self.project_path, "results", "logs", f"{episode_name}.txt"),
4860            os.path.join(
4861                self.project_path, "results", "logs", f"{new_episode_name}.txt"
4862            ),
4863        )
4864        self._episodes().rename_episode(episode_name, new_episode_name)

A class to create and maintain the project files + keep track of experiments

Project( name: str, data_type: str = None, annotation_type: str = 'none', projects_path: str = None, data_path: Union[str, List] = None, annotation_path: Union[str, List] = None, copy: bool = False)
 56    def __init__(
 57        self,
 58        name: str,
 59        data_type: str = None,
 60        annotation_type: str = "none",
 61        projects_path: str = None,
 62        data_path: Union[str, List] = None,
 63        annotation_path: Union[str, List] = None,
 64        copy: bool = False,
 65    ) -> None:
 66        """
 67        Parameters
 68        ----------
 69        name : str
 70            name of the project
 71        data_type : str, optional
 72            data type (run Project.data_types() to see available options; has to be provided if the project is being
 73            created)
 74        annotation_type : str, default 'none'
 75            annotation type (run Project.annotation_types() to see available options)
 76        projects_path : str, optional
 77            path to the projects folder (is filled with ~/DLC2Action by default)
 78        data_path : str, optional
 79            path to the folder containing input files for the project (has to be provided if the project is being
 80            created)
 81        annotation_path : str, optional
 82            path to the folder containing annotation files for the project
 83        copy : bool, default False
 84            if True, the files from annotation_path and data_path will be copied to the projects folder;
 85            otherwise they will be moved
 86        """
 87
 88        if projects_path is None:
 89            projects_path = os.path.join(str(Path.home()), "DLC2Action")
 90        if not os.path.exists(projects_path):
 91            os.mkdir(projects_path)
 92        self.project_path = os.path.join(projects_path, name)
 93        self.name = name
 94        self.data_type = data_type
 95        self.annotation_type = annotation_type
 96        self.data_path = data_path
 97        self.annotation_path = annotation_path
 98        if not os.path.exists(self.project_path):
 99            if data_type is None:
100                raise ValueError(
101                    "The data_type parameter is necessary when creating a new project!"
102                )
103            self._initialize_project(
104                data_type, annotation_type, data_path, annotation_path, copy
105            )
106        else:
107            self.annotation_type, self.data_type = self._read_types()
108            if data_type != self.data_type and data_type is not None:
109                raise ValueError(
110                    f"The project has already been initialized with data_type={self.data_type}!"
111                )
112            if annotation_type != self.annotation_type and annotation_type != "none":
113                raise ValueError(
114                    f"The project has already been initialized with annotation_type={self.annotation_type}!"
115                )
116            self.annotation_path, data_path = self._read_paths()
117            if self.data_path is None:
118                self.data_path = data_path
119            # if data_path != self.data_path and data_path is not None:
120            #     raise ValueError(
121            #         f"The project has already been initialized with data_path={self.data_path}!"
122            #     )
123            if annotation_path != self.annotation_path and annotation_path is not None:
124                raise ValueError(
125                    f"The project has already been initialized with annotation_path={self.annotation_path}!"
126                )
127        self._update_configs()

Parameters

name : str name of the project data_type : str, optional data type (run Project.data_types() to see available options; has to be provided if the project is being created) annotation_type : str, default 'none' annotation type (run Project.annotation_types() to see available options) projects_path : str, optional path to the projects folder (is filled with ~/DLC2Action by default) data_path : str, optional path to the folder containing input files for the project (has to be provided if the project is being created) annotation_path : str, optional path to the folder containing annotation files for the project copy : bool, default False if True, the files from annotation_path and data_path will be copied to the projects folder; otherwise they will be moved

def run_episode( self, episode_name: str, load_episode: str = None, parameters_update: Dict = None, task: dlc2action.task.task_dispatcher.TaskDispatcher = None, load_epoch: int = None, load_search: str = None, load_parameters: list = None, round_to_binary: list = None, load_strict: bool = True, n_seeds: int = 1, force: bool = False, suppress_name_check: bool = False, remove_saved_features: bool = False, mask_name: str = None, autostop_metric: str = None, autostop_interval: int = 50, autostop_threshold: float = 0.001, loading_bar: bool = False, trial: Tuple = None) -> dlc2action.task.task_dispatcher.TaskDispatcher:
618    def run_episode(
619        self,
620        episode_name: str,
621        load_episode: str = None,
622        parameters_update: Dict = None,
623        task: TaskDispatcher = None,
624        load_epoch: int = None,
625        load_search: str = None,
626        load_parameters: list = None,
627        round_to_binary: list = None,
628        load_strict: bool = True,
629        n_seeds: int = 1,
630        force: bool = False,
631        suppress_name_check: bool = False,
632        remove_saved_features: bool = False,
633        mask_name: str = None,
634        autostop_metric: str = None,
635        autostop_interval: int = 50,
636        autostop_threshold: float = 0.001,
637        loading_bar: bool = False,
638        trial: Tuple = None,
639    ) -> TaskDispatcher:
640        """
641        Run an episode
642
643        The task parameters are read from the config files and then updated with the
644        parameters_update dictionary. The model can be either initialized from scratch or loaded from one of the
645        previous experiments. All parameters and results are saved in the meta files and can be accessed with the
646        list_episodes() function. The train/test/validation split is saved and loaded from a file whenever the
647        same split parameters are used. The pre-computed datasets are also saved and loaded whenever the same
648        data parameters are used.
649
650        You can use the autostop parameters to finish training when the parameters are not improving. It will be
651        stopped if the average value of `autostop_metric` over the last `autostop_interval` epochs is smaller than
652        the average over the previous `autostop_interval` epochs + `autostop_threshold`. For example, if the
653        current epoch is 120 and `autostop_interval` is 50, the averages over epochs 70-120 and 20-70 will be compared.
654
655        Parameters
656        ----------
657        episode_name : str
658            the episode name
659        load_episode : str, optional
660            the (previously run) episode name to load the model from; if the episode has multiple runs,
661            the new episode will have the same number of runs, each starting with one of the pre-trained models
662        parameters_update : dict, optional
663            the dictionary used to update the parameters from the config files
664        task : TaskDispatcher, optional
665            a pre-existing `TaskDispatcher` object (if provided, the method will update it instead of creating
666            a new instance)
667        load_epoch : int, optional
668            the epoch to load (if load_episodes is not None); if not provided, the last epoch is used
669        load_search : str, optional
670            the hyperparameter search result to load
671        load_parameters : list, optional
672            a list of string names of the parameters to load from load_search (if not provided, all parameters
673            are loaded)
674        round_to_binary : list, optional
675            a list of string names of the loaded parameters that should be rounded to the nearest power of two
676        load_strict : bool, default True
677            if `False`, matching weights will be loaded from `load_episode` and differences in parameter name lists and
678            weight shapes will be ignored; otherwise mismatches will prompt a `RuntimeError`
679        n_seeds : int, default 1
680            the number of runs to perform with different random seeds; if `n_seeds > 1`, the episodes will be named
681            `episode_name::seed_index`, e.g. `test_episode::0` and `test_episode::1`
682        force : bool, default False
683            if `True` and an episode with name `episode_name` already exists, it will be overwritten (use with caution!)
684        suppress_name_check : bool, default False
685            if `True`, episode names with a double colon are allowed (please don't use this option unless you understand
686            why they are usually forbidden)
687        remove_saved_features : bool, default False
688            if `True`, the dataset will be deleted after training
689        mask_name : str, optional
690            the name of the real_lens to apply
691        autostop_interval : int, default 50
692            the number of epochs to average the autostop metric over
693        autostop_threshold : float, default 0.001
694            the autostop difference threshold
695        autostop_metric : str, optional
696            the autostop metric (can be any one of the tracked metrics of `'loss'`)
697        """
698
699        if type(n_seeds) is not int or n_seeds < 1:
700            raise ValueError(
701                f"The n_seeds parameter has to be an integer larger than 0; got {n_seeds}"
702            )
703        if n_seeds > 1 and mask_name is not None:
704            raise ValueError("Cannot apply a real_lens with n_seeds > 1")
705        self._check_episode_validity(
706            episode_name, allow_doublecolon=suppress_name_check, force=force
707        )
708        load_runs = self._episodes().get_runs(load_episode)
709        if len(load_runs) > 1:
710            task = self.run_episodes(
711                episode_names=[
712                    f'{episode_name}::{run.split("::")[-1]}' for run in load_runs
713                ],
714                load_episodes=load_runs,
715                parameters_updates=[parameters_update for _ in load_runs],
716                load_epochs=[load_epoch for _ in load_runs],
717                load_searches=[load_search for _ in load_runs],
718                load_parameters=[load_parameters for _ in load_runs],
719                round_to_binary=[round_to_binary for _ in load_runs],
720                load_strict=[load_strict for _ in load_runs],
721                suppress_name_check=True,
722                force=force,
723                remove_saved_features=False,
724            )
725            if remove_saved_features:
726                self._remove_stores(
727                    {
728                        "general": task.general_parameters,
729                        "data": task.data_parameters,
730                        "features": task.feature_parameters,
731                    }
732                )
733            if n_seeds > 1:
734                warnings.warn(
735                    f"The n_seeds parameter is disregarded since load_episode={load_episode} has multiple runs"
736                )
737        elif n_seeds > 1:
738            self.run_episodes(
739                episode_names=[f"{episode_name}::{i}" for i in range(n_seeds)],
740                load_episodes=[load_episode for _ in range(n_seeds)],
741                parameters_updates=[parameters_update for _ in range(n_seeds)],
742                load_epochs=[load_epoch for _ in range(n_seeds)],
743                load_searches=[load_search for _ in range(n_seeds)],
744                load_parameters=[load_parameters for _ in range(n_seeds)],
745                round_to_binary=[round_to_binary for _ in range(n_seeds)],
746                load_strict=[load_strict for _ in range(n_seeds)],
747                suppress_name_check=True,
748                force=force,
749                remove_saved_features=remove_saved_features,
750            )
751        else:
752            print(f"TRAINING {episode_name}")
753            try:
754                task, parameters = self._make_task_training(
755                    episode_name,
756                    load_episode,
757                    parameters_update,
758                    load_epoch,
759                    load_search,
760                    load_parameters,
761                    round_to_binary,
762                    continuing=False,
763                    task=task,
764                    mask_name=mask_name,
765                    load_strict=load_strict,
766                )
767                self._save_episode(
768                    episode_name,
769                    parameters,
770                    task.behaviors_dict(),
771                    norm_stats=task.get_normalization_stats(),
772                )
773                time_start = time.time()
774                if trial is not None:
775                    trial, metric = trial
776                else:
777                    trial, metric = None, None
778                logs = task.train(
779                    autostop_metric=autostop_metric,
780                    autostop_interval=autostop_interval,
781                    autostop_threshold=autostop_threshold,
782                    loading_bar=loading_bar,
783                    trial=trial,
784                    optimized_metric=metric,
785                )
786                time_end = time.time()
787                time_total = time_end - time_start
788                hours = int(time_total // 3600)
789                time_total -= hours * 3600
790                minutes = int(time_total // 60)
791                time_total -= minutes * 60
792                seconds = int(time_total)
793                training_time = f"{hours}:{minutes:02}:{seconds:02}"
794                self._update_episode_results(episode_name, logs, training_time)
795                if remove_saved_features:
796                    self._remove_stores(parameters)
797                print("\n")
798                return task
799
800            except Exception as e:
801                if isinstance(e, optuna.exceptions.TrialPruned):
802                    raise e
803                else:
804                    # if str(e) != f"The {episode_name} episode name is already in use!":
805                    #     self.remove_episode(episode_name)
806                    raise RuntimeError(f"Episode {episode_name} could not run")

Run an episode

The task parameters are read from the config files and then updated with the parameters_update dictionary. The model can be either initialized from scratch or loaded from one of the previous experiments. All parameters and results are saved in the meta files and can be accessed with the list_episodes() function. The train/test/validation split is saved and loaded from a file whenever the same split parameters are used. The pre-computed datasets are also saved and loaded whenever the same data parameters are used.

You can use the autostop parameters to finish training when the parameters are not improving. It will be stopped if the average value of autostop_metric over the last autostop_interval epochs is smaller than the average over the previous autostop_interval epochs + autostop_threshold. For example, if the current epoch is 120 and autostop_interval is 50, the averages over epochs 70-120 and 20-70 will be compared.

Parameters

episode_name : str the episode name load_episode : str, optional the (previously run) episode name to load the model from; if the episode has multiple runs, the new episode will have the same number of runs, each starting with one of the pre-trained models parameters_update : dict, optional the dictionary used to update the parameters from the config files task : TaskDispatcher, optional a pre-existing TaskDispatcher object (if provided, the method will update it instead of creating a new instance) load_epoch : int, optional the epoch to load (if load_episodes is not None); if not provided, the last epoch is used load_search : str, optional the hyperparameter search result to load load_parameters : list, optional a list of string names of the parameters to load from load_search (if not provided, all parameters are loaded) round_to_binary : list, optional a list of string names of the loaded parameters that should be rounded to the nearest power of two load_strict : bool, default True if False, matching weights will be loaded from load_episode and differences in parameter name lists and weight shapes will be ignored; otherwise mismatches will prompt a RuntimeError n_seeds : int, default 1 the number of runs to perform with different random seeds; if n_seeds > 1, the episodes will be named episode_name::seed_index, e.g. test_episode::0 and test_episode::1 force : bool, default False if True and an episode with name episode_name already exists, it will be overwritten (use with caution!) 
suppress_name_check : bool, default False if True, episode names with a double colon are allowed (please don't use this option unless you understand why they are usually forbidden) remove_saved_features : bool, default False if True, the dataset will be deleted after training mask_name : str, optional the name of the real_lens to apply autostop_interval : int, default 50 the number of epochs to average the autostop metric over autostop_threshold : float, default 0.001 the autostop difference threshold autostop_metric : str, optional the autostop metric (can be any one of the tracked metrics or 'loss')

def run_episodes( self, episode_names: List, load_episodes: List = None, parameters_updates: List = None, load_epochs: List = None, load_searches: List = None, load_parameters: List = None, round_to_binary: List = None, load_strict: List = None, force: bool = False, suppress_name_check: bool = False, remove_saved_features: bool = False) -> dlc2action.task.task_dispatcher.TaskDispatcher:
808    def run_episodes(
809        self,
810        episode_names: List,
811        load_episodes: List = None,
812        parameters_updates: List = None,
813        load_epochs: List = None,
814        load_searches: List = None,
815        load_parameters: List = None,
816        round_to_binary: List = None,
817        load_strict: List = None,
818        force: bool = False,
819        suppress_name_check: bool = False,
820        remove_saved_features: bool = False,
821    ) -> TaskDispatcher:
822        """
823        Run multiple episodes in sequence (and re-use previously loaded information)
824
825        For each episode, the task parameters are read from the config files and then updated with the
826        parameter_update dictionary. The model can be either initialized from scratch or loaded from one of the
827        previous experiments. All parameters and results are saved in the meta files and can be accessed with the
828        list_episodes() function. The train/test/validation split is saved and loaded from a file whenever the
829        same split parameters are used. The pre-computed datasets are also saved and loaded whenever the same
830        data parameters are used.
831
832        Parameters
833        ----------
834        episode_names : list
835            a list of strings of episode names
836        load_episodes : list, optional
837            a list of strings of (previously run) episode names to load the model from; if the episode has multiple runs,
838            the new episode will have the same number of runs, each starting with one of the pre-trained models
839        parameters_updates : list, optional
840            a list of dictionaries used to update the parameters from the config
841        load_epochs : list, optional
842            a list of integers used to specify the epoch to load (if load_episodes is not None)
843        load_searches : list, optional
844            a list of strings of hyperparameter search results to load
845        load_parameters : list, optional
846            a list of lists of string names of the parameters to load from the searches
847        round_to_binary : list, optional
848            a list of string names of the loaded parameters that should be rounded to the nearest power of two
849        load_strict : list, optional
850            a list of boolean values specifying weight loading policy: if `False`, matching weights will be loaded from
851            the corresponding episode and differences in parameter name lists and
852            weight shapes will be ignored; otherwise mismatches will prompt a `RuntimeError` (by default `True` for
853            every episode)
854        force : bool, default False
855            if `True` and an episode name is already taken, it will be overwritten (use with caution!)
856        suppress_name_check : bool, default False
857            if `True`, episode names with a double colon are allowed (please don't use this option unless you understand
858            why they are usually forbidden)
859        remove_saved_features : bool, default False
860            if `True`, the dataset will be deleted after training
861        """
862
863        task = None
864        if load_searches is None:
865            load_searches = [None for _ in episode_names]
866        if load_episodes is None:
867            load_episodes = [None for _ in episode_names]
868        if parameters_updates is None:
869            parameters_updates = [None for _ in episode_names]
870        if load_parameters is None:
871            load_parameters = [None for _ in episode_names]
872        if load_epochs is None:
873            load_epochs = [None for _ in episode_names]
874        if load_strict is None:
875            load_strict = [True for _ in episode_names]
876        for (
877            parameters_update,
878            episode_name,
879            load_episode,
880            load_epoch,
881            load_search,
882            load_parameters_list,
883            load_strict_value,
884        ) in zip(
885            parameters_updates,
886            episode_names,
887            load_episodes,
888            load_epochs,
889            load_searches,
890            load_parameters,
891            load_strict,
892        ):
893            task = self.run_episode(
894                episode_name,
895                load_episode,
896                parameters_update,
897                task,
898                load_epoch,
899                load_search,
900                load_parameters_list,
901                round_to_binary,
902                load_strict_value,
903                suppress_name_check=suppress_name_check,
904                force=force,
905                remove_saved_features=remove_saved_features,
906            )
907        return task

Run multiple episodes in sequence (and re-use previously loaded information)

For each episode, the task parameters are read from the config files and then updated with the parameter_update dictionary. The model can be either initialized from scratch or loaded from one of the previous experiments. All parameters and results are saved in the meta files and can be accessed with the list_episodes() function. The train/test/validation split is saved and loaded from a file whenever the same split parameters are used. The pre-computed datasets are also saved and loaded whenever the same data parameters are used.

Parameters

episode_names : list a list of strings of episode names load_episodes : list, optional a list of strings of (previously run) episode names to load the model from; if the episode has multiple runs, the new episode will have the same number of runs, each starting with one of the pre-trained models parameters_updates : list, optional a list of dictionaries used to update the parameters from the config load_epochs : list, optional a list of integers used to specify the epoch to load (if load_episodes is not None) load_searches : list, optional a list of strings of hyperparameter search results to load load_parameters : list, optional a list of lists of string names of the parameters to load from the searches round_to_binary : list, optional a list of string names of the loaded parameters that should be rounded to the nearest power of two load_strict : list, optional a list of boolean values specifying weight loading policy: if False, matching weights will be loaded from the corresponding episode and differences in parameter name lists and weight shapes will be ignored; otherwise mismatches will prompt a RuntimeError (by default True for every episode) force : bool, default False if True and an episode name is already taken, it will be overwritten (use with caution!) suppress_name_check : bool, default False if True, episode names with a double colon are allowed (please don't use this option unless you understand why they are usually forbidden) remove_saved_features : bool, default False if True, the dataset will be deleted after training

def continue_episode( self, episode_name: str, num_epochs: int = None, task: dlc2action.task.task_dispatcher.TaskDispatcher = None, n_seeds: int = 1, remove_saved_features: bool = False, device: str = 'cuda', num_cpus: int = None) -> dlc2action.task.task_dispatcher.TaskDispatcher:
def continue_episode(
    self,
    episode_name: str,
    num_epochs: int = None,
    task: TaskDispatcher = None,
    n_seeds: int = 1,
    remove_saved_features: bool = False,
    device: str = "cuda",
    num_cpus: int = None,
) -> TaskDispatcher:
    """
    Load an older episode and continue running from the latest checkpoint

    All parameters as well as the model and optimizer state dictionaries are loaded from the episode.

    Parameters
    ----------
    episode_name : str
        the name of the episode to continue
    num_epochs : int, optional
        the new number of epochs; if `None`, the runs that are not marked as unfinished are skipped
    task : TaskDispatcher, optional
        a pre-existing task; if provided, the method will update the task instead of creating a new one
        (this might save time, mainly on dataset loading)
    n_seeds : int, default 1
        the number of runs to perform; if the episode has fewer than `n_seeds` runs, the missing ones are
        started with `run_episode` using the parameters of the first existing run and are named
        `episode_name::run_index`, e.g. `test_episode::0` and `test_episode::1`
    remove_saved_features : bool, default False
        if `True`, pre-computed features will be deleted after the run
    device : str, default "cuda"
        the torch device to use
    num_cpus : int, optional
        the number of CPUs to use (written to the `general/num_cpus` parameter)

    Returns
    -------
    task : TaskDispatcher
        the task used for the last run (the `task` argument is returned unchanged if nothing was run)
    """

    # stays None when every run is skipped, so that the clean-up below has nothing to do
    parameters = None
    runs = self._episodes().get_runs(episode_name)
    for run in runs:
        print(f"TRAINING {run}")
        if num_epochs is None and not self._episode(run).unfinished():
            continue
        parameters_update = {
            "training": {
                "num_epochs": num_epochs,
                "device": device,
            },
            "general": {"num_cpus": num_cpus},
        }
        task, parameters = self._make_task_training(
            run,
            load_episode=run,
            parameters_update=parameters_update,
            continuing=True,
            task=task,
        )
        time_start = time.time()
        logs = task.train()
        time_end = time.time()
        old_time = self._training_time(run)
        if not np.isnan(old_time):
            # the new training time is added to the time stored for the run
            time_total = (time_end + old_time) - time_start
            minutes, seconds = divmod(int(time_total), 60)
            hours, minutes = divmod(minutes, 60)
            training_time = f"{hours}:{minutes:02}:{seconds:02}"
        else:
            training_time = np.nan
        self._save_episode(
            run,
            parameters,
            task.behaviors_dict(),
            suppress_validation=True,
            training_time=training_time,
            norm_stats=task.get_normalization_stats(),
        )
        self._update_episode_results(run, logs)
        print("\n")
    if len(runs) < n_seeds:
        for i in range(len(runs), n_seeds):
            self.run_episode(
                f"{episode_name}::{i}",
                parameters_update=self._episodes().load_parameters(runs[0]),
                task=task,
                suppress_name_check=True,
            )
    if remove_saved_features and parameters is not None:
        self._remove_stores(parameters)
    return task

Load an older episode and continue running from the latest checkpoint

All parameters as well as the model and optimizer state dictionaries are loaded from the episode.

Parameters

episode_name : str the name of the episode to continue num_epochs : int, optional the new number of epochs task : TaskDispatcher, optional a pre-existing task; if provided, the method will update the task instead of creating a new one (this might save time, mainly on dataset loading) result_average_interval : int, default 5 the metrics are averaged over the last result_average_interval to be stored in the episodes meta file and displayed by list_episodes() function (the full log is still always available) n_seeds : int, default 1 the number of runs to perform; if n_seeds > 1, the episodes will be named episode_name::run_index, e.g. test_episode::0 and test_episode::1 remove_saved_features : bool, default False if True, pre-computed features will be deleted after the run device : str, default "cuda" the torch device to use

def run_prediction( self, prediction_name: str, episode_names: List, load_epochs: List = None, parameters_update: Dict = None, augment_n: int = 10, data_path: str = None, mode: str = 'all', file_paths: Set = None, remove_saved_features: bool = False, submission: bool = False, frame_number_map_file: str = None, force: bool = False, embedding: bool = False) -> None:
def run_prediction(
    self,
    prediction_name: str,
    episode_names: List,
    load_epochs: List = None,
    parameters_update: Dict = None,
    augment_n: int = 10,
    data_path: str = None,
    mode: str = "all",
    file_paths: Set = None,
    remove_saved_features: bool = False,
    submission: bool = False,
    frame_number_map_file: str = None,
    force: bool = False,
    embedding: bool = False,
) -> None:
    """
    Load models from previously run episodes to generate a prediction

    The prediction is generated with `_make_prediction`; if that raises a `ValueError`,
    `_aggregate_predictions` is used instead. The results are saved in the prediction folder
    (`self.prediction_path(prediction_name)`), one pickled dictionary per video, under the
    `{video_id}_{prediction_name}_prediction.pickle` name. In each file the keys are the clip ids
    (like individual names) and the values are the prediction arrays; the `"min_frames"`, `"max_frames"`
    and `"behaviors"` keys are added before saving.

    Parameters
    ----------
    prediction_name : str
        the name of the prediction
    episode_names : list
        a list of string episode names to load the models from
    load_epochs : list, optional
        a list of integer epoch indices to load the model from; if None, the last ones are used
    parameters_update : dict, optional
        a dictionary of parameter updates
    augment_n : int, default 10
        the number of augmentations to average over
    data_path : str, optional
        the data path to run the prediction for
    mode : {'all', 'test', 'val', 'train'}
        the subset of the data to make the prediction for (forced to 'all' if data_path is not None)
    file_paths : set, optional
        a set of string file paths (data with all prefixes + feature files, in any order) to run the prediction
        for
    remove_saved_features : bool, default False
        if `True`, pre-computed features will be deleted
    submission : bool, default False
        if `True`, a MABe-22 style submission file is requested; this is not implemented yet and raises
        a `NotImplementedError`
    frame_number_map_file : str, optional
        path to the frame number map file (only relevant for submissions)
    force : bool, default False
        if `True`, existing prediction with this name will be overwritten
    embedding : bool, default False
        passed on to the prediction functions and stored with the prediction metadata
        (presumably switches the output to embeddings; confirm against `_make_prediction`)

    Raises
    ------
    NotImplementedError
        if `submission` is `True`
    """

    if submission:
        # TODO: add submission option to _make_prediction
        # There is no task to call `generate_submission` on yet (the branch used an `Ellipsis`
        # placeholder), so fail before the validity check can touch an existing prediction.
        raise NotImplementedError(
            "Generating submission files is not supported yet; please run with submission=False"
        )
    self._check_prediction_validity(prediction_name, force=force)
    print(f"PREDICTION {prediction_name}")
    try:
        (
            task,
            parameters,
            mode,
            prediction,
            inference_time,
        ) = self._make_prediction(
            prediction_name,
            episode_names,
            load_epochs,
            parameters_update,
            data_path,
            file_paths,
            mode,
            augment_n,
            evaluate=False,
            embedding=embedding,
        )
        predicted = task.dataset(mode).generate_full_length_prediction(prediction)
    except ValueError:
        (
            task,
            parameters,
            mode,
            predicted,
            inference_time,
        ) = self._aggregate_predictions(
            prediction_name,
            episode_names,
            load_epochs,
            parameters_update,
            data_path,
            file_paths,
            mode,
            augment_n,
            evaluate=False,
            embedding=embedding,
        )
    # the behavior names are the same for every video, so they are collected once
    behaviors_dict = task.behaviors_dict()
    behavior_indices = sorted([key for key in behaviors_dict if key != -100])
    behaviors = [behaviors_dict[key] for key in behavior_indices]
    dataset = task.dataset(mode)
    folder = self.prediction_path(prediction_name)
    os.mkdir(folder)
    for video_id, video_prediction in predicted.items():
        # fill in the metadata before opening the file so that a failure leaves no empty file behind
        (
            video_prediction["min_frames"],
            video_prediction["max_frames"],
        ) = dataset.get_min_max_frames(video_id)
        video_prediction["behaviors"] = list(behaviors)
        file_path = os.path.join(
            folder, video_id + f"_{prediction_name}_prediction.pickle"
        )
        with open(file_path, "wb") as f:
            pickle.dump(video_prediction, f)
    if remove_saved_features:
        self._remove_stores(parameters)
    self._save_prediction(
        prediction_name,
        parameters,
        behaviors_dict,
        embedding,
        inference_time,
    )
    print("\n")

Load models from previously run episodes to generate a prediction

The probabilities predicted by the models are averaged. Unless submission is True, the prediction results are saved as pickled dictionaries in the prediction folder, one file per video, under the {video_id}_{prediction_name}_prediction.pickle name. Each file is a dictionary where the keys are the clip ids (like individual names) and the values are the prediction arrays; the min_frames, max_frames and behaviors keys are added to it before saving.

Parameters

prediction_name : str the name of the prediction episode_names : list a list of string episode names to load the models from load_epochs : list, optional a list of integer epoch indices to load the model from; if None, the last ones are used parameters_update : dict, optional a dictionary of parameter updates augment_n : int, default 10 the number of augmentations to average over data_path : str, optional the data path to run the prediction for mode : {'all', 'test', 'val', 'train'} the subset of the data to make the prediction for (forced to 'all' if data_path is not None) file_paths : set, optional a set of string file paths (data with all prefixes + feature files, in any order) to run the prediction for remove_saved_features : bool, default False if True, pre-computed features will be deleted submission : bool, default False if True, a MABe-22 style submission file is generated frame_number_map_file : str, optional path to the frame number map file force : bool, default False if True, existing prediction with this name will be overwritten

def evaluate_prediction( self, prediction_name: str, parameters_update: Dict = None, data_path: str = None, file_paths: Set = None, mode: str = None, remove_saved_features: bool = False) -> Tuple[float, dict]:
def evaluate_prediction(
    self,
    prediction_name: str,
    parameters_update: Dict = None,
    data_path: str = None,
    file_paths: Set = None,
    mode: str = None,
    remove_saved_features: bool = False,
) -> Tuple[float, dict]:
    """
    Evaluate a previously saved prediction

    The prediction is read from the `results/predictions/{prediction_name}.pickle` file in the project
    folder. The task is created from the parameters stored for the prediction, updated with
    `parameters_update`; the `"model"` parameter group is dropped because no model is loaded.

    Parameters
    ----------
    prediction_name : str
        the name of the prediction
    parameters_update : dict, optional
        a dictionary of parameter updates
    data_path : str, optional
        the data path to run the evaluation for
    file_paths : set, optional
        a set of files to run the evaluation for
    mode : str, optional
        the subset of the data to evaluate on (passed to `_make_task_prediction`, which returns the
        subset that is actually used)
    remove_saved_features : bool, default False
        if `True`, the stores created for the task parameters will be removed

    Returns
    -------
    results : tuple
        the value returned by the task's `evaluate_prediction` method
    """

    # NOTE(review): `run_prediction` writes per-video files into a prediction folder; confirm that a
    # combined `{prediction_name}.pickle` file still exists for this method to read.
    prediction_file = os.path.join(
        self.project_path, "results", "predictions", f"{prediction_name}.pickle"
    )
    # Only open prediction files generated by this project: unpickling untrusted data can run arbitrary code.
    with open(prediction_file, "rb") as f:
        prediction = pickle.load(f)
    if parameters_update is None:
        parameters_update = {}
    parameters_update = self._update(
        self._predictions().load_parameters(prediction_name), parameters_update
    )
    # the stored parameters are not guaranteed to have a model group, so a missing key is not an error
    parameters_update.pop("model", None)
    task, parameters, mode = self._make_task_prediction(
        "_",
        load_episode=None,
        parameters_update=parameters_update,
        data_path=data_path,
        file_paths=file_paths,
        mode=mode,
    )
    results = task.evaluate_prediction(prediction, data=mode)
    if remove_saved_features:
        self._remove_stores(parameters)
    print("\n")
    return results
def evaluate( self, episode_names: List, load_epochs: List = None, augment_n: int = 0, data_path: str = None, file_paths: Set = None, mode: str = None, parameters_update: Dict = None, multiple_episode_policy: str = 'average', remove_saved_features: bool = False, skip_updating_meta: bool = True) -> Dict:
def evaluate(
    self,
    episode_names: List,
    load_epochs: List = None,
    augment_n: int = 0,
    data_path: str = None,
    file_paths: Set = None,
    mode: str = None,
    parameters_update: Dict = None,
    multiple_episode_policy: str = "average",
    remove_saved_features: bool = False,
    skip_updating_meta: bool = True,
) -> Dict:
    """
    Load one or several models from previously run episodes to make an evaluation

    By default it will run on the test (or validation, if there is no test) subset of the project dataset.

    Parameters
    ----------
    episode_names : list
        a list of string episode names to load the models from
    load_epochs : list, optional
        a list of integer epoch indices to load the model from; if None, the last ones are used
    augment_n : int, default 0
        the number of augmentations to average over
    data_path : str, optional
        the data path to run the prediction for
    file_paths : set, optional
        a set of files to run the prediction for
    mode : {'test', 'val', 'train', 'all'}
        the subset of the data to make the prediction for (forced to 'all' if data_path is not None;
        by default 'test' if test subset is not empty and 'val' otherwise)
    parameters_update : dict, optional
        a dictionary with parameter updates (cannot change model parameters)
    multiple_episode_policy : {'average', 'statistics'}
        with 'average' a single prediction is made from all episodes and evaluated once; with 'statistics'
        every run is evaluated separately and the mean and standard deviation of each metric are reported
    remove_saved_features : bool, default False
        if `True`, the dataset will be deleted
    skip_updating_meta : bool, default True
        if `False` and the evaluation is run on the 'val' subset, the metrics are written to the episode
        meta file (with the 'average' policy only when there is exactly one run)

    Returns
    -------
    metric : dict
        a dictionary of metric values; with the 'statistics' policy every value is a dictionary with
        'mean' and 'std' keys

    Raises
    ------
    ValueError
        if `multiple_episode_policy` is not recognized
    """

    names = []
    for episode_name in episode_names:
        names += self._episodes().get_runs(episode_name)
    if len(set(episode_names)) == 1:
        print(f"EVALUATION {episode_names[0]}")
    else:
        print(f"EVALUATION {episode_names}")
    several_runs = len(names) > 1
    # defined up front so that the final report works even when there are no runs to evaluate
    parameters = None
    inference_time = None
    if multiple_episode_policy == "average":
        try:
            (
                task,
                parameters,
                mode,
                prediction,
                inference_time,
            ) = self._make_prediction(
                "_",
                episode_names,
                load_epochs,
                parameters_update,
                mode=mode,
                data_path=data_path,
                file_paths=file_paths,
                augment_n=augment_n,
                evaluate=several_runs,
            )
        except Exception:
            # any regular error falls back to aggregation; KeyboardInterrupt and SystemExit
            # are not swallowed anymore
            (
                task,
                parameters,
                mode,
                prediction,
                inference_time,
            ) = self._aggregate_predictions(
                "_",
                episode_names,
                load_epochs,
                parameters_update,
                mode=mode,
                data_path=data_path,
                file_paths=file_paths,
                augment_n=augment_n,
                evaluate=several_runs,
            )
        print("AGGREGATED:")
        _, results = task.evaluate_prediction(prediction, data=mode)
        if len(names) == 1 and mode == "val" and not skip_updating_meta:
            self._update_episode_metrics(names[0], results)
    elif multiple_episode_policy == "statistics":
        values = defaultdict(list)
        task = None
        for run_name in names:
            (
                task,
                parameters,
                mode,
                prediction,
                inference_time,
            ) = self._make_prediction(
                "_",
                [run_name],
                load_epochs,
                parameters_update,
                mode=mode,
                data_path=data_path,
                file_paths=file_paths,
                augment_n=augment_n,
                evaluate=several_runs,
                task=task,
            )
            _, metrics = task.evaluate_prediction(prediction, data=mode)
            # the metric loop must not reuse the run variable: the meta update below needs the run name
            for metric_name, value in metrics.items():
                values[metric_name].append(value)
            if mode == "val" and not skip_updating_meta:
                self._update_episode_metrics(run_name, metrics)
        results = defaultdict(dict)
        mean_parts = []
        std_parts = []
        for key, value_list in values.items():
            mean = np.mean(value_list)
            std = np.std(value_list)
            results[key]["mean"] = mean
            results[key]["std"] = std
            mean_parts.append(f"{key} {mean:.3f}, ")
            std_parts.append(f"{key} {std:.3f}, ")
        print("MEAN:")
        print("".join(mean_parts))
        print("STD:")
        print("".join(std_parts))
    else:
        raise ValueError(
            f"The {multiple_episode_policy} multiple episode policy is not recognized; please choose "
            f"from ['average', 'statistics']"
        )
    if len(names) > 0 and remove_saved_features:
        self._remove_stores(parameters)
    print(f"Inference time: {inference_time}")
    print("\n")
    return results

Load one or several models from previously run episodes to make an evaluation

By default it will run on the test (or validation, if there is no test) subset of the project dataset.

Parameters

episode_names : list a list of string episode names to load the models from load_epochs : list, optional a list of integer epoch indices to load the model from; if None, the last ones are used augment_n : int, default 0 the number of augmentations to average over data_path : str, optional the data path to run the prediction for file_paths : set, optional a set of files to run the prediction for mode : {'test', 'val', 'train', 'all'} the subset of the data to make the prediction for (forced to 'all' if data_path is not None; by default 'test' if test subset is not empty and 'val' otherwise) parameters_update : dict, optional a dictionary with parameter updates (cannot change model parameters) remove_saved_features : bool, default False if True, the dataset will be deleted

Returns

metric : dict a dictionary of average values of metric functions

def list_episodes( self, episode_names: List = None, value_filter: str = '', display_parameters: List = None, print_results: bool = True) -> pandas.core.frame.DataFrame:
def list_episodes(
    self,
    episode_names: List = None,
    value_filter: str = "",
    display_parameters: List = None,
    print_results: bool = True,
) -> pd.DataFrame:
    """
    Collect the metadata of training episodes in a filtered pandas dataframe

    Parameters
    ----------
    episode_names : list, optional
        a list of strings of episode names
    value_filter : str
        a string of filters to apply; of this general structure:
        'group_name1/par_name1::(</>/<=/>=/=)value1,group_name2/par_name2::(</>/<=/>=/=)value2', e.g.
        'data/overlap::=50,results/recall::>0.5,data/feature_extraction::=kinematic,meta/training_time::>=00:00:10'
    display_parameters : list, optional
        list of parameters to display (e.g. ['data/overlap', 'results/recall'])
    print_results : bool, default True
        if True, the result will be printed to standard output

    Returns
    -------
    pd.DataFrame
        the filtered dataframe
    """

    episode_table = self._episodes().list_episodes(
        episode_names, value_filter, display_parameters
    )
    if not print_results:
        return episode_table
    print("TRAINING EPISODES")
    print(episode_table)
    print("\n")
    return episode_table

Get a filtered pandas dataframe with episode metadata

Parameters

episode_names : list a list of strings of episode names value_filter : str a string of filters to apply; of this general structure: 'group_name1/par_name1::(</>/<=/>=/=)value1,group_name2/par_name2::(</>/<=/>=/=)value2', e.g. 'data/overlap::=50,results/recall::>0.5,data/feature_extraction::=kinematic,meta/training_time::>=00:00:10' display_parameters : list list of parameters to display (e.g. ['data/overlap', 'results/recall']) print_results : bool, default True if True, the result will be printed to standard output

Returns

pd.DataFrame the filtered dataframe

def list_predictions( self, episode_names: List = None, value_filter: str = '', display_parameters: List = None, print_results: bool = True) -> pandas.core.frame.DataFrame:
def list_predictions(
    self,
    episode_names: List = None,
    value_filter: str = "",
    display_parameters: List = None,
    print_results: bool = True,
) -> pd.DataFrame:
    """
    Collect the metadata of predictions in a filtered pandas dataframe

    Parameters
    ----------
    episode_names : list, optional
        a list of strings of names to look up in the prediction meta file
    value_filter : str
        a string of filters to apply; of this general structure:
        'group_name1/par_name1:(<>=)value1,group_name2/par_name2:(<>=)value2', e.g.
        'data/overlap:=50,results/recall:>0.5,data/feature_extraction:=kinematic'
        (NOTE(review): `list_episodes` documents `::` as the separator; confirm which one applies here)
    display_parameters : list, optional
        list of parameters to display (e.g. ['data/overlap', 'results/recall'])
    print_results : bool, default True
        if True, the result will be printed to standard output

    Returns
    -------
    pd.DataFrame
        the filtered dataframe
    """

    prediction_table = self._predictions().list_episodes(
        episode_names, value_filter, display_parameters
    )
    if not print_results:
        return prediction_table
    print("PREDICTIONS")
    print(prediction_table)
    print("\n")
    return prediction_table

Get a filtered pandas dataframe with prediction metadata

Parameters

episode_names : list a list of strings of episode names value_filter : str a string of filters to apply; of this general structure: 'group_name1/par_name1:(<>=)value1,group_name2/par_name2:(<>=)value2', e.g. 'data/overlap:=50,results/recall:>0.5,data/feature_extraction:=kinematic' display_parameters : list list of parameters to display (e.g. ['data/overlap', 'results/recall']) print_results : bool, default True if True, the result will be printed to standard output

Returns

pd.DataFrame the filtered dataframe

def list_searches( self, search_names: List = None, value_filter: str = '', display_parameters: List = None, print_results: bool = True) -> pandas.core.frame.DataFrame:
def list_searches(
    self,
    search_names: List = None,
    value_filter: str = "",
    display_parameters: List = None,
    print_results: bool = True,
) -> pd.DataFrame:
    """
    Collect the metadata of hyperparameter searches in a filtered pandas dataframe

    Parameters
    ----------
    search_names : list, optional
        a list of strings of search names
    value_filter : str
        a string of filters to apply; of this general structure:
        'group_name1/par_name1:(<>=)value1,group_name2/par_name2:(<>=)value2', e.g.
        'data/overlap:=50,results/recall:>0.5,data/feature_extraction:=kinematic'
        (NOTE(review): `list_episodes` documents `::` as the separator; confirm which one applies here)
    display_parameters : list, optional
        list of parameters to display (e.g. ['data/overlap', 'results/recall'])
    print_results : bool, default True
        if True, the result will be printed to standard output

    Returns
    -------
    pd.DataFrame
        the filtered dataframe
    """

    search_table = self._searches().list_episodes(
        search_names, value_filter, display_parameters
    )
    if not print_results:
        return search_table
    print("SEARCHES")
    print(search_table)
    print("\n")
    return search_table

Get a filtered pandas dataframe with hyperparameter search metadata

Parameters

search_names : list a list of strings of search names value_filter : str a string of filters to apply; of this general structure: 'group_name1/par_name1:(<>=)value1,group_name2/par_name2:(<>=)value2', e.g. 'data/overlap:=50,results/recall:>0.5,data/feature_extraction:=kinematic' display_parameters : list list of parameters to display (e.g. ['data/overlap', 'results/recall']) print_results : bool, default True if True, the result will be printed to standard output

Returns

pd.DataFrame the filtered dataframe

def get_best_parameters(self, search_name: str, round_to_binary: List = None)
def get_best_parameters(
    self,
    search_name: str,
    round_to_binary: List = None,
):
    """
    Get the best parameters found by a search as a parameter update

    The model name returned by the search is written to the `general/model_name` parameter.

    Parameters
    ----------
    search_name : str
        the name of the search
    round_to_binary : list, optional
        a list of string names of the parameters that should be rounded to the nearest power of two

    Returns
    -------
    params : dict
        the best parameters updated with the model name
    """

    search_meta = self._searches()
    best_params, model_name = search_meta.get_best_params(
        search_name, round_to_binary=round_to_binary
    )
    model_update = {"general": {"model_name": model_name}}
    return self._update(best_params, model_update)
def list_best_parameters(self, search_name: str, print_results: bool = True) -> Dict:
def list_best_parameters(
    self, search_name: str, print_results: bool = True
) -> Dict:
    """
    Look up the raw dictionary of the best parameters found by a search

    Parameters
    ----------
    search_name : str
        the name of the search
    print_results : bool, default True
        if True, the result will be printed to standard output

    Returns
    -------
    best_params : dict
        a dictionary of the best parameters where the keys are in '{group}/{name}' format
    """

    best_params = self._searches().get_best_params_raw(search_name)
    if not print_results:
        return best_params
    print(f"SEARCH RESULTS {search_name}")
    for key, value in best_params.items():
        print(f"{key}: {value}")
    print("\n")
    return best_params

Get the raw dictionary of best parameters found by a search

Parameters

search_name : str the name of the search print_results : bool, default True if True, the result will be printed to standard output

Returns

best_params : dict a dictionary of the best parameters where the keys are in '{group}/{name}' format

def plot_episodes( self, episode_names: List, metrics: List, modes: List = None, title: str = None, episode_labels: List = None, save_path: str = None, add_hlines: List = None, epoch_limits: List = None, colors: List = None, add_highpoint_hlines: bool = False) -> None:
1791    def plot_episodes(
1792        self,
1793        episode_names: List,
1794        metrics: List,
1795        modes: List = None,
1796        title: str = None,
1797        episode_labels: List = None,
1798        save_path: str = None,
1799        add_hlines: List = None,
1800        epoch_limits: List = None,
1801        colors: List = None,
1802        add_highpoint_hlines: bool = False,
1803    ) -> None:
1804        """
1805        Plot episode training curves
1806
1807        Parameters
1808        ----------
1809        episode_names : list
1810            a list of episode names to plot; to plot to episodes in one line combine them in a list
1811            (e.g. ['episode1', ['episode2', 'episode3']] to plot episode2 and episode3 as one experiment)
1812        metrics : list
1813            a list of metric to plot
1814        modes : list, optional
1815            a list of modes to plot ('train' and/or 'val'; `['val']` by default)
1816        title : str, optional
1817            title for the plot
1818        episode_labels : list, optional
1819            a list of strings used to label the curves (has to be the same length as episode_names)
1820        save_path : str, optional
1821            the path to save the resulting plot
1822        add_hlines : list, optional
1823            a list of float values (or (value, label) tuples) to mark with horizontal lines
1824        colors: list, optional
1825            a list of matplotlib colors
1826        add_highpoint_hlines : bool, default False
1827            if `True`, horizontal lines will be added at the highest value of each episode
1828        """
1829
1830        if modes is None:
1831            modes = ["val"]
1832        if add_hlines is None:
1833            add_hlines = []
1834        logs = []
1835        epochs = []
1836        labels = []
1837        if episode_labels is not None:
1838            assert len(episode_labels) == len(episode_names)
1839        for name_i, name in enumerate(episode_names):
1840            log_params = product(metrics, modes)
1841            for metric, mode in log_params:
1842                if episode_labels is not None:
1843                    label = episode_labels[name_i]
1844                else:
1845                    label = deepcopy(name)
1846                if len(modes) != 1:
1847                    label += f"_{mode}"
1848                if len(metrics) != 1:
1849                    label += f"_{metric}"
1850                labels.append(label)
1851                if isinstance(name, Iterable) and not isinstance(name, str):
1852                    epoch_list = defaultdict(lambda: [])
1853                    multi_logs = defaultdict(lambda: [])
1854                    for i, n in enumerate(name):
1855                        runs = self._episodes().get_runs(n)
1856                        if len(runs) > 1:
1857                            for run in runs:
1858                                index = run.split("::")[-1]
1859                                if multi_logs[index] == []:
1860                                    if multi_logs["null"] is None:
1861                                        raise RuntimeError(
1862                                            "The run indices are not consistent across episodes!"
1863                                        )
1864                                    else:
1865                                        multi_logs[index] += multi_logs["null"]
1866                                multi_logs[index] += list(
1867                                    self._episode(run).get_metric_log(mode, metric)
1868                                )
1869                                start = (
1870                                    0
1871                                    if len(epoch_list[index]) == 0
1872                                    else epoch_list[index][-1]
1873                                )
1874                                epoch_list[index] += [
1875                                    x + start
1876                                    for x in self._episode(run).get_epoch_list(mode)
1877                                ]
1878                            multi_logs["null"] = None
1879                        else:
1880                            if len(multi_logs.keys()) > 1:
1881                                raise RuntimeError(
1882                                    "Cannot plot a single-run episode after a multi-run episode!"
1883                                )
1884                            multi_logs["null"] += list(
1885                                self._episode(n).get_metric_log(mode, metric)
1886                            )
1887                            start = (
1888                                0
1889                                if len(epoch_list["null"]) == 0
1890                                else epoch_list["null"][-1]
1891                            )
1892                            epoch_list["null"] += [
1893                                x + start for x in self._episode(n).get_epoch_list(mode)
1894                            ]
1895                    if len(multi_logs.keys()) == 1:
1896                        log = multi_logs["null"]
1897                        epochs.append(epoch_list["null"])
1898                    else:
1899                        log = tuple([v for k, v in multi_logs.items() if k != "null"])
1900                        epochs.append(
1901                            tuple([v for k, v in epoch_list.items() if k != "null"])
1902                        )
1903                else:
1904                    runs = self._episodes().get_runs(name)
1905                    if len(runs) > 1:
1906                        log = []
1907                        for run in runs:
1908                            tracked_metrics = self._episode(run).get_metrics()
1909                            if metric in tracked_metrics:
1910                                log.append(
1911                                    list(
1912                                        self._episode(run).get_metric_log(mode, metric)
1913                                    )
1914                                )
1915                            else:
1916                                relevant = []
1917                                for m in tracked_metrics:
1918                                    m_split = m.split("_")
1919                                    if (
1920                                        "_".join(m_split[:-1]) == metric
1921                                        and m_split[-1].isnumeric()
1922                                    ):
1923                                        relevant.append(m)
1924                                if len(relevant) == 0:
1925                                    raise ValueError(
1926                                        f"The {metric} metric was not tracked at {run}"
1927                                    )
1928                                arr = 0
1929                                for m in relevant:
1930                                    arr += self._episode(run).get_metric_log(mode, m)
1931                                arr /= len(relevant)
1932                                log.append(list(arr))
1933                        log = tuple(log)
1934                        epochs.append(
1935                            tuple(
1936                                [
1937                                    self._episode(run).get_epoch_list(mode)
1938                                    for run in runs
1939                                ]
1940                            )
1941                        )
1942                    else:
1943                        tracked_metrics = self._episode(name).get_metrics()
1944                        if metric in tracked_metrics:
1945                            log = list(self._episode(name).get_metric_log(mode, metric))
1946                        else:
1947                            relevant = []
1948                            for m in tracked_metrics:
1949                                m_split = m.split("_")
1950                                if (
1951                                    "_".join(m_split[:-1]) == metric
1952                                    and m_split[-1].isnumeric()
1953                                ):
1954                                    relevant.append(m)
1955                            if len(relevant) == 0:
1956                                raise ValueError(
1957                                    f"The {metric} metric was not tracked at {name}"
1958                                )
1959                            arr = 0
1960                            for m in relevant:
1961                                arr += self._episode(name).get_metric_log(mode, m)
1962                            arr /= len(relevant)
1963                            log = list(arr)
1964                        epochs.append(self._episode(name).get_epoch_list(mode))
1965                logs.append(log)
1966        # if episode_labels is not None:
1967        #     print(f'{len(episode_labels)=}, {len(logs)=}')
1968        #     if len(episode_labels) != len(logs):
1969
1970        #         raise ValueError(
1971        #             f"The length of episode_labels ({len(episode_labels)}) has to be equal to the length of "
1972        #             f"curves ({len(logs)})!"
1973        #         )
1974        #     else:
1975        #         labels = episode_labels
1976        if colors is None:
1977            colors = cm.rainbow(np.linspace(0, 1, len(logs)))
1978        if len(colors) != len(logs):
1979            raise ValueError(
1980                "The length of colors has to be equal to the length of curves (metrics * modes * episode_names)!"
1981            )
1982        plt.figure()
1983        length = 0
1984        for log, label, color, epoch_list in zip(logs, labels, colors, epochs):
1985            if type(log) is list:
1986                if len(log) > length:
1987                    length = len(log)
1988                plt.plot(
1989                    epoch_list,
1990                    log,
1991                    label=label,
1992                    color=color,
1993                )
1994                if add_highpoint_hlines:
1995                    plt.axhline(np.max(log), linestyle="dashed", color=color)
1996            else:
1997                for l, xx in zip(log, epoch_list):
1998                    if len(l) > length:
1999                        length = len(l)
2000                    plt.plot(
2001                        xx,
2002                        l,
2003                        color=color,
2004                        alpha=0.2,
2005                    )
2006                if not all([len(x) == len(log[0]) for x in log]):
2007                    warnings.warn(
2008                        f"Got logs with unequal lengths in parallel runs for {label}"
2009                    )
2010                    log = list(log)
2011                    epoch_list = list(epoch_list)
2012                    for i, x in enumerate(epoch_list):
2013                        to_remove = []
2014                        for j, y in enumerate(x[1:]):
2015                            if y <= x[j - 1]:
2016                                y_ind = x.index(y)
2017                                to_remove += list(range(y_ind, j))
2018                        epoch_list[i] = [
2019                            y for j, y in enumerate(x) if j not in to_remove
2020                        ]
2021                        log[i] = [y for j, y in enumerate(log[i]) if j not in to_remove]
2022                    length = min([len(x) for x in log])
2023                    for i in range(len(log)):
2024                        log[i] = log[i][:length]
2025                        epoch_list[i] = epoch_list[i][:length]
2026                    if not all([x == epoch_list[0] for x in epoch_list]):
2027                        raise RuntimeError(
2028                            f"Got different epoch indices in parallel runs for {label}"
2029                        )
2030                mean = np.array(log).mean(0)
2031                plt.plot(
2032                    epoch_list[0],
2033                    mean,
2034                    label=label,
2035                    color=color,
2036                )
2037                if add_highpoint_hlines:
2038                    plt.axhline(np.max(mean), linestyle="dashed", color=color)
2039        for x in add_hlines:
2040            label = None
2041            if isinstance(x, Iterable):
2042                x, label = x
2043            plt.axhline(x, label=label)
2044            plt.xlim((0, length))
2045
2046        plt.legend()
2047        plt.xlabel("epochs")
2048        if len(metrics) == 1:
2049            plt.ylabel(metrics[0])
2050        else:
2051            plt.ylabel("value")
2052        if title is None:
2053            if len(episode_names) == 1:
2054                title = episode_names[0]
2055            elif len(metrics) == 1:
2056                title = metrics[0]
2057        if epoch_limits is not None:
2058            plt.xlim(epoch_limits)
2059        if title is not None:
2060            plt.title(title)
2061        plt.show()
2062        if save_path is not None:
2063            plt.savefig(save_path)

Plot episode training curves

Parameters

episode_names : list a list of episode names to plot; to plot two episodes in one line combine them in a list (e.g. ['episode1', ['episode2', 'episode3']] to plot episode2 and episode3 as one experiment) metrics : list a list of metrics to plot modes : list, optional a list of modes to plot ('train' and/or 'val'; ['val'] by default) title : str, optional title for the plot episode_labels : list, optional a list of strings used to label the curves (has to be the same length as episode_names) save_path : str, optional the path to save the resulting plot add_hlines : list, optional a list of float values (or (value, label) tuples) to mark with horizontal lines epoch_limits : list, optional the x-axis limits of the plot colors: list, optional a list of matplotlib colors add_highpoint_hlines : bool, default False if True, horizontal lines will be added at the highest value of each episode

def update_parameters( self, parameters_update: Dict = None, load_search: str = None, load_parameters: List = None, round_to_binary: List = None) -> None:
2065    def update_parameters(
2066        self,
2067        parameters_update: Dict = None,
2068        load_search: str = None,
2069        load_parameters: List = None,
2070        round_to_binary: List = None,
2071    ) -> None:
2072        """
2073        Update the parameters in the project config files
2074
2075        Parameters
2076        ----------
2077        parameters_update : dict, optional
2078            a dictionary of parameter updates
2079        load_search : str, optional
2080            the name of hyperparameter search results to load to config
2081        load_parameters : list, optional
2082            a list of lists of string names of the parameters to load from the searches
2083        round_to_binary : list, optional
2084            a list of string names of the loaded parameters that should be rounded to the nearest power of two
2085        """
2086
2087        keys = [
2088            "general",
2089            "losses",
2090            "metrics",
2091            "ssl",
2092            "training",
2093            "data",
2094        ]
2095        parameters = self._read_parameters(catch_blanks=False)
2096        if parameters_update is not None:
2097            if "model" in parameters_update:
2098                model_params = parameters_update.pop("model")
2099            else:
2100                model_params = None
2101            if "features" in parameters_update:
2102                feat_params = parameters_update.pop("features")
2103            else:
2104                feat_params = None
2105            if "augmentations" in parameters_update:
2106                aug_params = parameters_update.pop("augmentations")
2107            else:
2108                aug_params = None
2109            parameters = self._update(parameters, parameters_update)
2110            model_name = parameters["general"]["model_name"]
2111            parameters["model"] = self._open_yaml(
2112                os.path.join(self.project_path, "config", "model", f"{model_name}.yaml")
2113            )
2114            if model_params is not None:
2115                parameters["model"] = self._update(parameters["model"], model_params)
2116            feat_name = parameters["general"]["feature_extraction"]
2117            parameters["features"] = self._open_yaml(
2118                os.path.join(
2119                    self.project_path, "config", "features", f"{feat_name}.yaml"
2120                )
2121            )
2122            if feat_params is not None:
2123                parameters["features"] = self._update(
2124                    parameters["features"], feat_params
2125                )
2126            aug_name = options.extractor_to_transformer[
2127                parameters["general"]["feature_extraction"]
2128            ]
2129            parameters["augmentations"] = self._open_yaml(
2130                os.path.join(
2131                    self.project_path, "config", "augmentations", f"{aug_name}.yaml"
2132                )
2133            )
2134            if aug_params is not None:
2135                parameters["augmentations"] = self._update(
2136                    parameters["augmentations"], aug_params
2137                )
2138        if load_search is not None:
2139            parameters_update, model_name = self._searches().get_best_params(
2140                load_search, load_parameters, round_to_binary
2141            )
2142            parameters["general"]["model_name"] = model_name
2143            parameters["model"] = self._open_yaml(
2144                os.path.join(self.project_path, "config", "model", f"{model_name}.yaml")
2145            )
2146            parameters = self._update(parameters, parameters_update)
2147        for key in keys:
2148            with open(
2149                os.path.join(self.project_path, "config", f"{key}.yaml"), "w", encoding="utf-8"
2150            ) as f:
2151                YAML().dump(parameters[key], f)
2152        model_name = parameters["general"]["model_name"]
2153        model_path = os.path.join(
2154            self.project_path, "config", "model", f"{model_name}.yaml"
2155        )
2156        with open(model_path, "w", encoding="utf-8") as f:
2157            YAML().dump(parameters["model"], f)
2158        features_name = parameters["general"]["feature_extraction"]
2159        features_path = os.path.join(
2160            self.project_path, "config", "features", f"{features_name}.yaml"
2161        )
2162        with open(features_path, "w", encoding="utf-8") as f:
2163            YAML().dump(parameters["features"], f)
2164        aug_name = options.extractor_to_transformer[features_name]
2165        aug_path = os.path.join(
2166            self.project_path, "config", "augmentations", f"{aug_name}.yaml"
2167        )
2168        with open(aug_path, "w", encoding="utf-8") as f:
2169            YAML().dump(parameters["augmentations"], f)

Update the parameters in the project config files

Parameters

parameters_update : dict, optional a dictionary of parameter updates load_search : str, optional the name of hyperparameter search results to load to config load_parameters : list, optional a list of lists of string names of the parameters to load from the searches round_to_binary : list, optional a list of string names of the loaded parameters that should be rounded to the nearest power of two

def get_summary( self, episode_names: list, method: str = 'last', average: int = 1, metrics: List = None) -> Dict:
2171    def get_summary(
2172        self,
2173        episode_names: list,
2174        method: str = "last",
2175        average: int = 1,
2176        metrics: List = None,
2177    ) -> Dict:
2178        """
2179        Get a summary of episode statistics
2180
2181        If the episode has multiple runs, the statistics will be aggregated over all of them.
2182
2183        Parameters
2184        ----------
2185        episode_name : str
2186            the name of the episode
2187        method : ["best", "last"]
2188            the method for choosing the epochs
2189        average : int, default 1
2190            the number of epochs to average over (for each run)
2191        metrics : list, optional
2192            a list of metrics
2193
2194        Returns
2195        -------
2196        statistics : dict
2197            a nested dictionary where first-level keys are metric names and second-level keys are 'mean' for the mean
2198            and 'std' for the standard deviation
2199        """
2200
2201        runs = []
2202        for episode_name in episode_names:
2203            runs_ep = self._episodes().get_runs(episode_name)
2204            if len(runs_ep) == 0:
2205                raise RuntimeError(
2206                    f"There is no {episode_name} episode in the project memory"
2207                )
2208            runs += runs_ep
2209        if metrics is None:
2210            metrics = self._episode(runs[0]).get_metrics()
2211
2212        values = {m: [] for m in metrics}
2213        for run in runs:
2214            for m in metrics:
2215                log = self._episode(run).get_metric_log(mode="val", metric_name=m)
2216                if method == "best":
2217                    log = sorted(log)
2218                    values[m] += list(log[-average:])
2219                elif method == "last":
2220                    if len(log) == 0:
2221                        episodes = self._episodes().data
2222                        if average == 1 and ("results", m) in episodes.columns:
2223                            values[m] += [episodes.loc[run, ("results", m)]]
2224                        else:
2225                            raise RuntimeError(f"Did not find {m} metric for {run} run")
2226                    values[m] += list(log[-average:])
2227                elif method.startswith("epoch"):
2228                    epoch = int(method[5:]) - 1
2229                    pars = self._episodes().load_parameters(run)
2230                    step = int(pars["training"]["validation_interval"])
2231                    values[m] += [log[epoch // step]]
2232                else:
2233                    raise ValueError(
2234                        f"The {method} method is not recognized! Please choose from ['last', 'best', 'epoch...']"
2235                    )
2236        statistics = defaultdict(lambda: {})
2237        for m, v in values.items():
2238            statistics[m]["mean"] = np.mean(v)
2239            statistics[m]["std"] = np.std(v)
2240        print(f"SUMMARY {episode_names}")
2241        for m, v in statistics.items():
2242            print(f'{m}: mean {v["mean"]:.3f}, std {v["std"]:.3f}')
2243        print("\n")
2244        return dict(statistics)

Get a summary of episode statistics

If the episode has multiple runs, the statistics will be aggregated over all of them.

Parameters

episode_names : list a list of the names of the episodes to summarize method : ["best", "last", "epoch{N}"] the method for choosing the epochs ("epoch{N}" selects the value logged for epoch N) average : int, default 1 the number of epochs to average over (for each run) metrics : list, optional a list of metrics

Returns

statistics : dict a nested dictionary where first-level keys are metric names and second-level keys are 'mean' for the mean and 'std' for the standard deviation

@staticmethod
def remove_project(name: str, projects_path: str = None) -> None:
2246    @staticmethod
2247    def remove_project(name: str, projects_path: str = None) -> None:
2248        """
2249        Remove all project files and experiment records and results
2250        """
2251
2252        if projects_path is None:
2253            projects_path = os.path.join(str(Path.home()), "DLC2Action")
2254        project_path = os.path.join(projects_path, name)
2255        if os.path.exists(project_path):
2256            shutil.rmtree(project_path)

Remove all project files and experiment records and results

def remove_saved_features( self, dataset_names: List = None, exceptions: List = None, remove_active: bool = False) -> None:
2258    def remove_saved_features(
2259        self,
2260        dataset_names: List = None,
2261        exceptions: List = None,
2262        remove_active: bool = False,
2263    ) -> None:
2264        """
2265        Remove saved pre-computed dataset files
2266
2267        By default, all pre-computed features will be deleted.
2268        No essential information can get lost, storing them only saves time. Be careful with deleting datasets
2269        while training or inference is happening though.
2270
2271        Parameters
2272        ----------
2273        dataset_names : list, optional
2274            a list of dataset names to delete (by default all names are added)
2275        exceptions : list, optional
2276            a list of dataset names to not be deleted
2277        remove_active : bool, default False
2278            if `False`, datasets used by unfinished episodes will not be deleted
2279        """
2280
2281        print("Removing datasets...")
2282        if dataset_names is None:
2283            dataset_names = []
2284        if exceptions is None:
2285            exceptions = []
2286        if not remove_active:
2287            exceptions += self._episodes().get_active_datasets()
2288        dataset_path = os.path.join(self.project_path, "saved_datasets")
2289        if os.path.exists(dataset_path):
2290            if dataset_names == []:
2291                dataset_names = set([f.split(".")[0] for f in os.listdir(dataset_path)])
2292
2293            to_remove = [
2294                x
2295                for x in dataset_names
2296                if os.path.exists(os.path.join(dataset_path, x)) and x not in exceptions
2297            ]
2298            if len(to_remove) > 2:
2299                to_remove = tqdm(to_remove)
2300            for dataset in to_remove:
2301                shutil.rmtree(os.path.join(dataset_path, dataset))
2302            to_remove = [
2303                f"{x}.pickle"
2304                for x in dataset_names
2305                if os.path.exists(os.path.join(dataset_path, f"{x}.pickle"))
2306                and x not in exceptions
2307            ]
2308            for dataset in to_remove:
2309                os.remove(os.path.join(dataset_path, dataset))
2310            names = self._saved_datasets().dataset_names()
2311            self._saved_datasets().remove(names)
2312        print("\n")

Remove saved pre-computed dataset files

By default, all pre-computed features will be deleted. No essential information can get lost, storing them only saves time. Be careful with deleting datasets while training or inference is happening though.

Parameters

dataset_names : list, optional a list of dataset names to delete (by default all names are added) exceptions : list, optional a list of dataset names to not be deleted remove_active : bool, default False if False, datasets used by unfinished episodes will not be deleted

def remove_extra_checkpoints(self, episode_names: List = None, exceptions: List = None) -> None:
2314    def remove_extra_checkpoints(
2315        self, episode_names: List = None, exceptions: List = None
2316    ) -> None:
2317        """
2318        Remove intermediate model checkpoint files (only leave the results of the last epoch)
2319
2320        By default, all intermediate checkpoints will be deleted.
2321        Files in the model folder that are not associated with any record in the meta files are also deleted.
2322
2323        Parameters
2324        ----------
2325        episode_names : list, optional
2326            a list of episode names to clean (by default all names are added)
2327        exceptions : list, optional
2328            a list of episode names to not clean
2329        """
2330
2331        model_path = os.path.join(self.project_path, "results", "model")
2332        try:
2333            all_names = self._episodes().data.index
2334        except:
2335            all_names = os.listdir(model_path)
2336        if episode_names is None:
2337            episode_names = all_names
2338        if exceptions is None:
2339            exceptions = []
2340        to_remove = [x for x in episode_names if x not in exceptions]
2341        folders = os.listdir(model_path)
2342        for folder in folders:
2343            if folder not in all_names:
2344                shutil.rmtree(os.path.join(model_path, folder))
2345            elif folder in to_remove:
2346                files = os.listdir(os.path.join(model_path, folder))
2347                for file in sorted(files)[:-1]:
2348                    os.remove(os.path.join(model_path, folder, file))

Remove intermediate model checkpoint files (only leave the results of the last epoch)

By default, all intermediate checkpoints will be deleted. Files in the model folder that are not associated with any record in the meta files are also deleted.

Parameters

episode_names : list, optional a list of episode names to clean (by default all names are added) exceptions : list, optional a list of episode names to not clean

def remove_prediction(self, prediction_name: str) -> None:
2365    def remove_prediction(self, prediction_name: str) -> None:
2366        """
2367        Remove a prediction record
2368
2369        Parameters
2370        ----------
2371        prediction_name : str
2372            the name of the prediction to remove
2373        """
2374
2375        self._predictions().remove_episode(prediction_name)
2376        prediction_path = os.path.join(
2377            self.project_path, "results", "predictions", prediction_name
2378        )
2379        if os.path.exists(prediction_path):
2380            shutil.rmtree(prediction_path)

Remove a prediction record

Parameters

prediction_name : str the name of the prediction to remove

def remove_episode(self, episode_name: str) -> None:
2382    def remove_episode(self, episode_name: str) -> None:
2383        """
2384        Remove all model, logs and metafile records related to an episode
2385
2386        Parameters
2387        ----------
2388        episode_name : str
2389            the name of the episode to remove
2390        """
2391
2392        runs = self._episodes().get_runs(episode_name)
2393        runs.append(episode_name)
2394        for run in runs:
2395            self._episodes().remove_episode(run)
2396            model_path = os.path.join(self.project_path, "results", "model", run)
2397            if os.path.exists(model_path):
2398                shutil.rmtree(model_path)
2399            log_path = os.path.join(self.project_path, "results", "logs", f"{run}.txt")
2400            if os.path.exists(log_path):
2401                os.remove(log_path)

Remove all model, logs and metafile records related to an episode

Parameters

episode_name : str the name of the episode to remove

def prune_unfinished(self, exceptions: List = None) -> None:
2403    def prune_unfinished(self, exceptions: List = None) -> None:
2404        """
2405        Remove all interrupted episodes
2406
2407        Remove all episodes that either don't have a log file or have less epochs in the log file than in
2408        the training parameters or have a model folder but not a record. Note that it can remove episodes that are
2409        currently running!
2410
2411        Parameters
2412        ----------
2413        exceptions : list
2414            the episodes to keep even if they are interrupted
2415
2416        Returns
2417        -------
2418        pruned : list
2419            a list of the episode names that were pruned
2420        """
2421
2422        if exceptions is None:
2423            exceptions = []
2424        unfinished = self._episodes().unfinished_episodes()
2425        unfinished = [x for x in unfinished if x not in exceptions]
2426        model_folders = os.listdir(os.path.join(self.project_path, "results", "model"))
2427        unfinished += [
2428            x for x in model_folders if x not in self._episodes().list_episodes().index
2429        ]
2430        print(f"PRUNING {unfinished}")
2431        for episode_name in unfinished:
2432            self.remove_episode(episode_name)
2433        print(f"\n")
2434        return unfinished

Remove all interrupted episodes

Remove all episodes that either don't have a log file or have fewer epochs in the log file than in the training parameters or have a model folder but not a record. Note that it can remove episodes that are currently running!

Parameters

exceptions : list the episodes to keep even if they are interrupted

Returns

pruned : list a list of the episode names that were pruned

def prediction_path(self, prediction_name: str) -> str:
2436    def prediction_path(self, prediction_name: str) -> str:
2437        """
2438        Get the path where prediction files are saved
2439
2440        Parameters
2441        ----------
2442        prediction_name : str
2443            name of the prediction
2444
2445        Returns
2446        -------
2447        prediction_path : str
2448            the file path
2449        """
2450
2451        return os.path.join(
2452            self.project_path, "results", "predictions", f"{prediction_name}"
2453        )

Get the path where prediction files are saved

Parameters

prediction_name : str name of the prediction

Returns

prediction_path : str the file path

@classmethod
def print_data_types(cls)
2455    @classmethod
2456    def print_data_types(cls):
2457        print("DATA TYPES:")
2458        for key, value in cls.data_types().items():
2459            print(f"{key}:")
2460            print(value.__doc__)
@classmethod
def print_annotation_types(cls)
2462    @classmethod
2463    def print_annotation_types(cls):
2464        print("ANNOTATION TYPES:")
2465        for key, value in cls.annotation_types().items():
2466            print(f"{key}:")
2467            print(value.__doc__)
@staticmethod
def data_types() -> List:
2469    @staticmethod
2470    def data_types() -> List:
2471        """
2472        Get available data types
2473
2474        Returns
2475        -------
2476        list
2477            available data types
2478        """
2479
2480        return options.input_stores

Get available data types

Returns

list available data types

@staticmethod
def annotation_types() -> List:
2482    @staticmethod
2483    def annotation_types() -> List:
2484        """
2485        Get available annotation types
2486
2487        Returns
2488        -------
2489        list
2490            available annotation types
2491        """
2492
2493        return options.annotation_stores

Get available annotation types

Returns

list available annotation types

def set_main_parameters(self, model_name: str = None, metric_names: List = None)
3075    def set_main_parameters(self, model_name: str = None, metric_names: List = None):
3076        """
3077        Select the model and the metrics
3078
3079        Parameters
3080        ----------
3081        model_name : str, optional
3082            model name; run `project.help("model") to find out more
3083        metric_names : list, optional
3084            a list of metric function names; run `project.help("metrics") to find out more
3085        """
3086
3087        pars = {"general": {}}
3088        if model_name is not None:
3089            assert model_name in options.models
3090            pars["general"]["model_name"] = model_name
3091        if metric_names is not None:
3092            for metric in metric_names:
3093                assert metric in options.metrics
3094            pars["general"]["metric_functions"] = metric_names
3095        self.update_parameters(pars)

Select the model and the metrics

Parameters

model_name : str, optional model name; run project.help("model") to find out more metric_names : list, optional a list of metric function names; run project.help("metrics") to find out more

def help(self, keyword: str = None)
3097    def help(self, keyword: str = None):
3098        """
3099        Get information on available options
3100
3101        Parameters
3102        ----------
3103        keyword : str, optional
3104            the keyword for options (run without arguments to see which keywords are available)
3105
3106        """
3107
3108        if keyword is None:
3109            print("AVAILABLE HELP FUNCTIONS:")
3110            print("- Try running `project.help(keyword)` with the following keywords:")
3111            print("    - model: to get more information on available models,")
3112            print(
3113                "    - features: to get more information on available feature extraction modes,"
3114            )
3115            print(
3116                "    - partition_method: to get more information on available train/test/val partitioning methods,"
3117            )
3118            print("    - metrics: to see a list of available metric functions.")
3119            print("    - data: to see help for expected data structure")
3120            print(
3121                "- To start working with this project, first run `project.list_blanks()` to check which parameters need to be filled in."
3122            )
3123            print(
3124                "- After a model and metrics are set, run `project.list_basic_parameters()` to see a list of the most important parameters that you might want to modify"
3125            )
3126            print(
3127                f"- If you want to dig deeper, get the full dictionary with project._read_parameters() (it is a `ruamel.yaml.comments.CommentedMap` instance)."
3128            )
3129        elif keyword == "model":
3130            print("MODELS:")
3131            for key, model in options.models.items():
3132                print(f"{key}:")
3133                print(model.__doc__)
3134        elif keyword == "features":
3135            print("FEATURE EXTRACTORS:")
3136            for key, extractor in options.feature_extractors.items():
3137                print(f"{key}:")
3138                print(extractor.__doc__)
3139        elif keyword == "partition_method":
3140            print("PARTITION METHODS:")
3141            print(
3142                BehaviorDataset.partition_train_test_val.__doc__.split(
3143                    "The partitioning method:"
3144                )[1].split("val_frac :")[0]
3145            )
3146        elif keyword == "metrics":
3147            print("METRICS:")
3148            for key, metric in options.metrics.items():
3149                print(f"{key}:")
3150                print(metric.__doc__)
3151        elif keyword == "data":
3152            print("DATA:")
3153            print(f"Video data: {self.data_type}")
3154            print(options.input_stores[self.data_type].__doc__)
3155            print(f"Annotation data: {self.annotation_type}")
3156            print(options.annotation_stores[self.annotation_type].__doc__)
3157            print(
3158                "Annotation path and data path don't have to be separate, you can keep everything in one folder."
3159            )
3160        else:
3161            raise ValueError(f"The {keyword} keyword is not recognized")
3162        print("\n")

Get information on available options

Parameters

keyword : str, optional the keyword for options (run without arguments to see which keywords are available)

def list_blanks(self, blanks=None)
3182    def list_blanks(self, blanks=None):
3183        """
3184        List parameters that need to be filled in
3185
3186        Parameters
3187        ----------
3188        blanks : list, optional
3189            a list of the parameters to list, if already known
3190        """
3191
3192        if blanks is None:
3193            blanks = self._get_blanks()
3194        if len(blanks) > 0:
3195            to_update = defaultdict(lambda: [])
3196            for b, k, c in blanks:
3197                to_update[b].append((k, c))
3198            print("Before running experiments, please update all the blanks.")
3199            print("To do that, you can run this.")
3200            print("--------------------------------------------------------")
3201            print(f"project.update_parameters(")
3202            print(f"    {{")
3203            for big_key, keys in to_update.items():
3204                print(f'        "{big_key}": {{')
3205                for key, comment in keys:
3206                    print(f'            "{key}": ..., {comment}')
3207                print(f"        }},")
3208            print(f"    }}")
3209            print(")")
3210            print("--------------------------------------------------------")
3211            print("Replace ... with relevant values.")
3212        else:
3213            print("There is no blanks left!")

List parameters that need to be filled in

Parameters

blanks : list, optional a list of the parameters to list, if already known

def list_basic_parameters(self)
3215    def list_basic_parameters(
3216        self,
3217    ):
3218        """
3219        Get a list of most relevant parameters and code to modify them
3220        """
3221
3222        parameters = self._read_parameters()
3223        print("BASIC PARAMETERS:")
3224        model_name = parameters["general"]["model_name"]
3225        metric_names = parameters["general"]["metric_functions"]
3226        loss_name = parameters["general"]["loss_function"]
3227        feature_extraction = parameters["general"]["feature_extraction"]
3228        print("Here is a list of current parameters.")
3229        print(
3230            "You can copy this code, change the parameters you want to set and run it to update the project config."
3231        )
3232        print("--------------------------------------------------------")
3233        print("project.update_parameters(")
3234        print("    {")
3235        for group in ["general", "data", "training"]:
3236            print(f'        "{group}": {{')
3237            for key in options.basic_parameters[group]:
3238                if key in parameters[group]:
3239                    print(
3240                        f'            "{key}": {self._process_value(parameters[group][key])}, {self._get_comment(parameters[group].ca.items.get(key))}'
3241                    )
3242            print("        },")
3243        print('        "losses": {')
3244        print(f'            "{loss_name}": {{')
3245        for key in options.basic_parameters["losses"][loss_name]:
3246            if key in parameters["losses"][loss_name]:
3247                print(
3248                    f'                "{key}": {self._process_value(parameters["losses"][loss_name][key])}, {self._get_comment(parameters["losses"][loss_name].ca.items.get(key))}'
3249                )
3250        print("            },")
3251        print("        },")
3252        print('        "metrics": {')
3253        for metric in metric_names:
3254            print(f'            "{metric}": {{')
3255            for key in parameters["metrics"][metric]:
3256                print(
3257                    f'                "{key}": {self._process_value(parameters["metrics"][metric][key])}, {self._get_comment(parameters["metrics"][metric].ca.items.get(key))}'
3258                )
3259            print("            },")
3260        print("        },")
3261        print('        "model": {')
3262        for key in options.basic_parameters["model"][model_name]:
3263            if key in parameters["model"]:
3264                print(
3265                    f'            "{key}": {self._process_value(parameters["model"][key])}, {self._get_comment(parameters["model"].ca.items.get(key))}'
3266                )
3267
3268        print("        },")
3269        print('        "features": {')
3270        for key in options.basic_parameters["features"][feature_extraction]:
3271            if key in parameters["features"]:
3272                print(
3273                    f'            "{key}": {self._process_value(parameters["features"][key])}, {self._get_comment(parameters["features"].ca.items.get(key))}'
3274                )
3275
3276        print("        },")
3277        print('        "augmentations": {')
3278        for key in options.basic_parameters["augmentations"][feature_extraction]:
3279            if key in parameters["augmentations"]:
3280                print(
3281                    f'            "{key}": {self._process_value(parameters["augmentations"][key])}, {self._get_comment(parameters["augmentations"].ca.items.get(key))}'
3282                )
3283        print("        },")
3284        print("    },")
3285        print(")")
3286        print("--------------------------------------------------------")
3287        print("\n")

Get a list of most relevant parameters and code to modify them

def count_classes( self, load_episode: str = None, parameters_update: Dict = None, remove_saved_features: bool = False, bouts: bool = True) -> Dict:
3788    def count_classes(
3789        self,
3790        load_episode: str = None,
3791        parameters_update: Dict = None,
3792        remove_saved_features: bool = False,
3793        bouts: bool = True,
3794    ) -> Dict:
3795        """
3796        Get a dictionary of class counts in different modes
3797
3798        Parameters
3799        ----------
3800        load_episode : str, optional
3801            the episode settings to load
3802        parameters_update : dict, optional
3803            a dictionary of parameter updates (only for "data" and "general" categories)
3804        remove_saved_features : bool, default False
3805            if `True`, the dataset that is used for computation is then deleted
3806        bouts : bool, default False
3807            if `True`, instead of frame counts segment counts are returned
3808
3809        Returns
3810        -------
3811        class_counts : dict
3812            a dictionary where first-level keys are "train", "val" and "test", second-level keys are
3813            class names and values are class counts (in frames)
3814        """
3815
3816        if load_episode is None:
3817            task, parameters = self._make_task_training(
3818                episode_name="_", parameters_update=parameters_update, throwaway=True
3819            )
3820        else:
3821            task, parameters, _ = self._make_task_prediction(
3822                "_",
3823                load_episode=load_episode,
3824                parameters_update=parameters_update,
3825            )
3826        class_counts = task.count_classes(bouts=bouts)
3827        behaviors = task.behaviors_dict()
3828        class_counts = {
3829            kk: {behaviors.get(k, "unknown"): v for k, v in vv.items()}
3830            for kk, vv in class_counts.items()
3831        }
3832        if remove_saved_features:
3833            self._remove_stores(parameters)
3834        return class_counts

Get a dictionary of class counts in different modes

Parameters

load_episode : str, optional the episode settings to load parameters_update : dict, optional a dictionary of parameter updates (only for "data" and "general" categories) remove_saved_features : bool, default False if True, the dataset that is used for computation is then deleted bouts : bool, default True if True, instead of frame counts segment counts are returned

Returns

class_counts : dict a dictionary where first-level keys are "train", "val" and "test", second-level keys are class names and values are class counts (in frames)

def plot_class_distribution( self, parameters_update: Dict = None, frame_cutoff: int = 1, bout_cutoff: int = 1, print_full: bool = False, remove_saved_features: bool = False) -> None:
3836    def plot_class_distribution(
3837        self,
3838        parameters_update: Dict = None,
3839        frame_cutoff: int = 1,
3840        bout_cutoff: int = 1,
3841        print_full: bool = False,
3842        remove_saved_features: bool = False,
3843    ) -> None:
3844        """
3845        Make a class distribution plot
3846
3847        You can either specify the parameters, choose an existing dataset or do neither (in that case a dataset
3848        is created or laoded for the computation with the default parameters).
3849
3850        Parameters
3851        ----------
3852        parameters_update : dict, optional
3853            a dictionary of parameter updates (only for "data" and "general" categories)
3854        remove_saved_features : bool, default False
3855            if `True`, the dataset that is used for computation is then deleted
3856        """
3857
3858        task, parameters = self._make_task_training(
3859            episode_name="_", parameters_update=parameters_update, throwaway=True
3860        )
3861        cutoff = {True: bout_cutoff, False: frame_cutoff}
3862        for bouts in [True, False]:
3863            class_counts = task.count_classes(bouts=bouts)
3864            if print_full:
3865                print("Bouts:" if bouts else "Frames:")
3866                for k, v in class_counts.items():
3867                    if sum(v.values()) != 0:
3868                        print(f"  {k}:")
3869                        values, keys = zip(
3870                            *[
3871                                x
3872                                for x in sorted(zip(v.values(), v.keys()), reverse=True)
3873                                if x[-1] != -100
3874                            ]
3875                        )
3876                        for kk, vv in zip(keys, values):
3877                            print(f"    {task.behaviors_dict()[kk]}: {vv}")
3878            class_counts = {
3879                kk: {k: v for k, v in vv.items() if v >= cutoff[bouts]}
3880                for kk, vv in class_counts.items()
3881            }
3882            for key, d in class_counts.items():
3883                if sum(d.values()) != 0:
3884                    values, keys = zip(
3885                        *[x for x in sorted(zip(d.values(), d.keys())) if x[-1] != -100]
3886                    )
3887                    keys = [task.behaviors_dict()[x] for x in keys]
3888                    plt.bar(keys, values)
3889                    plt.title(key)
3890                    plt.xticks(rotation=45, ha="right")
3891                    if bouts:
3892                        plt.ylabel("bouts")
3893                    else:
3894                        plt.ylabel("frames")
3895                    plt.tight_layout()
3896                    plt.show()
3897        if remove_saved_features:
3898            self._remove_stores(parameters)

Make a class distribution plot

You can either specify the parameters, choose an existing dataset or do neither (in that case a dataset is created or loaded for the computation with the default parameters).

Parameters

parameters_update : dict, optional a dictionary of parameter updates (only for "data" and "general" categories) remove_saved_features : bool, default False if True, the dataset that is used for computation is then deleted

def plot_confusion_matrix( self, episode_name: str, load_epoch: int = None, parameters_update: Dict = None, type: str = 'recall', mode: str = 'val', remove_saved_features: bool = False) -> Tuple[numpy.ndarray, collections.abc.Iterable]:
4338    def plot_confusion_matrix(
4339        self,
4340        episode_name: str,
4341        load_epoch: int = None,
4342        parameters_update: Dict = None,
4343        type: str = "recall",
4344        mode: str = "val",
4345        remove_saved_features: bool = False,
4346    ) -> Tuple[ndarray, Iterable]:
4347        """
4348        Make a confusion matrix plot and return the data
4349
4350        If the annotation is non-exclusive, only false positive labels are considered.
4351
4352        Parameters
4353        ----------
4354        episode_name : str
4355            the name of the episode to load
4356        load_epoch : int, optional
4357            the index of the epoch to load (by default the last one is loaded)
4358        parameters_update : dict, optional
4359            a dictionary of parameter updates (only for "data" and "general" categories)
4360        mode : {'val', 'all', 'test', 'train'}
4361            the subset of the data to make the prediction for (forced to 'all' if data_path is not None)
4362        type : {"recall", "precision"}
4363            for datasets with non-exclusive annotation, if `type` is `"recall"`, only false positives are taken
4364            into account, and if `type` is `"precision"`, only false negatives
4365        remove_saved_features : bool, default False
4366            if `True`, the dataset that is used for computation is then deleted
4367
4368        Returns
4369        -------
4370        confusion_matrix : np.ndarray
4371            a confusion matrix of shape `(#classes, #classes)` where `A[i, j] = F_ij/N_i`, `F_ij` is the number of
4372            frames that have the i-th label in the ground truth and a false positive j-th label in the prediction,
4373            `N_i` is the number of frames that have the i-th label in the ground truth
4374        classes : list
4375            a list of labels
4376        """
4377
4378        task, parameters, mode = self._make_task_prediction(
4379            "_",
4380            load_episode=episode_name,
4381            load_epoch=load_epoch,
4382            parameters_update=parameters_update,
4383            mode=mode,
4384        )
4385        dataset = task.dataset(mode)
4386        prediction = task.predict(dataset, raw_output=True)
4387        confusion_matrix, classes, type = dataset.get_confusion_matrix(prediction, type)
4388        if remove_saved_features:
4389            self._remove_stores(parameters)
4390        fig, ax = plt.subplots(figsize=(len(classes), len(classes)))
4391        ax.imshow(confusion_matrix)
4392        # Show all ticks and label them with the respective list entries
4393        ax.set_xticks(np.arange(len(classes)))
4394        ax.set_xticklabels(classes)
4395        ax.set_yticks(np.arange(len(classes)))
4396        ax.set_yticklabels(classes)
4397        # Rotate the tick labels and set their alignment.
4398        plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
4399        # Loop over data dimensions and create text annotations.
4400        for i in range(len(classes)):
4401            for j in range(len(classes)):
4402                ax.text(
4403                    j,
4404                    i,
4405                    np.round(confusion_matrix[i, j], 2),
4406                    ha="center",
4407                    va="center",
4408                    color="w",
4409                )
4410        if type is not None:
4411            ax.set_title(f"{type} {episode_name}")
4412        else:
4413            ax.set_title(episode_name)
4414        fig.tight_layout()
4415        plt.show()
4416        return confusion_matrix, classes

Make a confusion matrix plot and return the data

If the annotation is non-exclusive, only false positive labels are considered.

Parameters

episode_name : str the name of the episode to load load_epoch : int, optional the index of the epoch to load (by default the last one is loaded) parameters_update : dict, optional a dictionary of parameter updates (only for "data" and "general" categories) mode : {'val', 'all', 'test', 'train'} the subset of the data to make the prediction for (forced to 'all' if data_path is not None) type : {"recall", "precision"} for datasets with non-exclusive annotation, if type is "recall", only false positives are taken into account, and if type is "precision", only false negatives remove_saved_features : bool, default False if True, the dataset that is used for computation is then deleted

Returns

confusion_matrix : np.ndarray a confusion matrix of shape (#classes, #classes) where A[i, j] = F_ij/N_i, F_ij is the number of frames that have the i-th label in the ground truth and a false positive j-th label in the prediction, N_i is the number of frames that have the i-th label in the ground truth classes : list a list of labels

def plot_predictions( self, episode_name: str, load_epoch: int = None, parameters_update: Dict = None, add_legend: bool = True, ground_truth: bool = True, colormap: str = 'viridis', hide_axes: bool = False, min_classes: int = 1, width: float = 10, whole_video: bool = False, transparent: bool = False, drop_classes: Set = None, search_classes: Set = None, num_plots: int = 1, remove_saved_features: bool = False, smooth_interval_prediction: int = 0, data_path: str = None, file_paths: Set = None, mode: str = 'val', behavior_name: str = None) -> None:
4418    def plot_predictions(
4419        self,
4420        episode_name: str,
4421        load_epoch: int = None,
4422        parameters_update: Dict = None,
4423        add_legend: bool = True,
4424        ground_truth: bool = True,
4425        colormap: str = "viridis",
4426        hide_axes: bool = False,
4427        min_classes: int = 1,
4428        width: float = 10,
4429        whole_video: bool = False,
4430        transparent: bool = False,
4431        drop_classes: Set = None,
4432        search_classes: Set = None,
4433        num_plots: int = 1,
4434        remove_saved_features: bool = False,
4435        smooth_interval_prediction: int = 0,
4436        data_path: str = None,
4437        file_paths: Set = None,
4438        mode: str = "val",
4439        behavior_name: str = None,
4440    ) -> None:
4441        """
4442        Visualize random predictions
4443
4444        Parameters
4445        ----------
4446        episode_name : str
4447            the name of the episode to load
4448        load_epoch : int, optional
4449            the epoch to load (by default last)
4450        parameters_update : dict, optional
4451            parameter update dictionary
4452        add_legend : bool, default True
4453            if True, legend will be added to the plot
4454        ground_truth : bool, default True
4455            if True, ground truth will be added to the plot
4456        colormap : str, default 'Accent'
4457            the `matplotlib` colormap to use
4458        hide_axes : bool, default True
4459            if `True`, the axes will be hidden on the plot
4460        min_classes : int, default 1
4461            the minimum number of classes in a displayed interval
4462        width : float, default 10
4463            the width of the plot
4464        whole_video : bool, default False
4465            if `True`, whole videos are plotted instead of segments
4466        transparent : bool, default False
4467            if `True`, the background on the plot is transparent
4468        drop_classes : set, optional
4469            a set of class names to not be displayed
4470        search_classes : set, optional
4471            if given, only intervals where at least one of the classes is in ground truth will be shown
4472        num_plots : int, default 1
4473            the number of plots to make
4474        remove_saved_features : bool, default False
4475            if `True`, the dataset will be deleted after computation
4476        smooth_interval_prediction : int, default 0
4477            if >0, predictions shorter than this number of frames are removed (filled with prediction for the previous frame)
4478        data_path : str, optional
4479            the data path to run the prediction for
4480        mode : {'all', 'test', 'val', 'train'}
4481            the subset of the data to make the prediction for (forced to 'all' if data_path is not None)
4482        file_paths : set, optional
4483            a set of string file paths (data with all prefixes + feature files, in any order) to run the prediction
4484            for
4485        behavior_name : str, optional
4486            for non-exclusive classificaton datasets, choose which behavior to visualize (by default first in list)
4487        """
4488
4489        other_path = os.path.join(self.project_path, "results", "other")
4490        task, parameters, mode = self._make_task_prediction(
4491            "_",
4492            load_episode=episode_name,
4493            parameters_update=parameters_update,
4494            load_epoch=load_epoch,
4495            data_path=data_path,
4496            file_paths=file_paths,
4497            mode=mode,
4498        )
4499        if not os.path.exists(other_path):
4500            os.mkdir(other_path)
4501        for i in range(num_plots):
4502            task.visualize_results(
4503                save_path=os.path.join(
4504                    other_path, f"{episode_name}_prediction_{i}.jpg"
4505                ),
4506                add_legend=add_legend,
4507                ground_truth=ground_truth,
4508                colormap=colormap,
4509                hide_axes=hide_axes,
4510                min_classes=min_classes,
4511                whole_video=whole_video,
4512                transparent=transparent,
4513                dataset=mode,
4514                drop_classes=drop_classes,
4515                search_classes=search_classes,
4516                width=width,
4517                smooth_interval_prediction=smooth_interval_prediction,
4518                behavior_name=behavior_name,
4519            )
4520        if remove_saved_features:
4521            self._remove_stores(parameters)

Visualize random predictions

Parameters

episode_name : str the name of the episode to load load_epoch : int, optional the epoch to load (by default last) parameters_update : dict, optional parameter update dictionary add_legend : bool, default True if True, legend will be added to the plot ground_truth : bool, default True if True, ground truth will be added to the plot colormap : str, default 'viridis' the matplotlib colormap to use hide_axes : bool, default False if True, the axes will be hidden on the plot min_classes : int, default 1 the minimum number of classes in a displayed interval width : float, default 10 the width of the plot whole_video : bool, default False if True, whole videos are plotted instead of segments transparent : bool, default False if True, the background on the plot is transparent drop_classes : set, optional a set of class names to not be displayed search_classes : set, optional if given, only intervals where at least one of the classes is in ground truth will be shown num_plots : int, default 1 the number of plots to make remove_saved_features : bool, default False if True, the dataset will be deleted after computation smooth_interval_prediction : int, default 0 if >0, predictions shorter than this number of frames are removed (filled with prediction for the previous frame) data_path : str, optional the data path to run the prediction for mode : {'all', 'test', 'val', 'train'} the subset of the data to make the prediction for (forced to 'all' if data_path is not None) file_paths : set, optional a set of string file paths (data with all prefixes + feature files, in any order) to run the prediction for behavior_name : str, optional for non-exclusive classification datasets, choose which behavior to visualize (by default first in list)

def create_metadata_backup(self) -> None:
4523    def create_metadata_backup(self) -> None:
4524        """
4525        Create a copy of the meta files
4526        """
4527
4528        meta_copy_path = os.path.join(self.project_path, "meta", "backup")
4529        meta_path = os.path.join(self.project_path, "meta")
4530        if os.path.exists(meta_copy_path):
4531            shutil.rmtree(meta_copy_path)
4532        os.mkdir(meta_copy_path)
4533        for file in os.listdir(meta_path):
4534            if file == "backup":
4535                continue
4536            shutil.copy(
4537                os.path.join(meta_path, file), os.path.join(meta_copy_path, file)
4538            )

Create a copy of the meta files

def load_metadata_backup(self) -> None:
4540    def load_metadata_backup(self) -> None:
4541        """
4542        Load from previously created meta data backup (in case of corruption)
4543        """
4544
4545        meta_copy_path = os.path.join(self.project_path, "meta", "backup")
4546        meta_path = os.path.join(self.project_path, "meta")
4547        for file in os.listdir(meta_copy_path):
4548            shutil.copy(
4549                os.path.join(meta_copy_path, file), os.path.join(meta_path, file)
4550            )

Load from previously created meta data backup (in case of corruption)

def get_behavior_dictionary(self, episode_name: str) -> Dict:
4552    def get_behavior_dictionary(self, episode_name: str) -> Dict:
4553        """
4554        Get the behavior dictionary for an episode
4555
4556        Parameters
4557        ----------
4558        episode_name : str
4559            the name of the episode
4560
4561        Returns
4562        -------
4563        behaviors_dictionary : dict
4564            a dictionary where keys are label indices and values are label names
4565        """
4566
4567        run = self._episodes().get_runs(episode_name)[0]
4568        return self._episode(run).get_behaviors_dict()

Get the behavior dictionary for an episode

Parameters

episode_name : str the name of the episode

Returns

behaviors_dictionary : dict a dictionary where keys are label indices and values are label names

def import_episodes( self, episodes_directory: str, name_map: Dict = None, repeat_policy: str = 'error') -> None:
4570    def import_episodes(
4571        self,
4572        episodes_directory: str,
4573        name_map: Dict = None,
4574        repeat_policy: str = "error",
4575    ) -> None:
4576        """
4577        Import episodes exported with `Project.export_episodes`
4578
4579        Parameters
4580        ----------
4581        episodes_directory : str
4582            the path to the exported episodes directory
4583        name_map : dict
4584            a name change dictionary for the episodes: keys are old names, values are new names
4585        """
4586
4587        if name_map is None:
4588            name_map = {}
4589        episodes = pd.read_pickle(os.path.join(episodes_directory, "episodes.pickle"))
4590        to_remove = []
4591        import_string = "Imported episodes: "
4592        for episode_name in episodes.index:
4593            if episode_name in name_map:
4594                import_string += f"{episode_name} "
4595                episode_name = name_map[episode_name]
4596                import_string += f"({episode_name}), "
4597            else:
4598                import_string += f"{episode_name}, "
4599            try:
4600                self._check_episode_validity(episode_name, allow_doublecolon=True)
4601            except ValueError as e:
4602                if str(e).endswith("is already taken!"):
4603                    if repeat_policy == "skip":
4604                        to_remove.append(episode_name)
4605                    elif repeat_policy == "force":
4606                        self.remove_episode(episode_name)
4607                    elif repeat_policy == "error":
4608                        raise ValueError(
4609                            f"The {episode_name} episode name is already taken; please use the name_map parameter to rename it"
4610                        )
4611                    else:
4612                        raise ValueError(
4613                            f"The {repeat_policy} repeat policy is not recognized; please choose from ['skip', 'force' ans 'error']"
4614                        )
4615        episodes = episodes.drop(index=to_remove)
4616        self._episodes().update(
4617            episodes,
4618            name_map=name_map,
4619            force=(repeat_policy == "force"),
4620            data_path=self.data_path,
4621            annotation_path=self.annotation_path,
4622        )
4623        for episode_name in episodes.index:
4624            if episode_name in name_map:
4625                new_episode_name = name_map[episode_name]
4626            else:
4627                new_episode_name = episode_name
4628            model_dir = os.path.join(
4629                self.project_path, "results", "model", new_episode_name
4630            )
4631            old_model_dir = os.path.join(episodes_directory, "model", episode_name)
4632            if os.path.exists(model_dir):
4633                shutil.rmtree(model_dir)
4634            os.mkdir(model_dir)
4635            for file in os.listdir(old_model_dir):
4636                shutil.copyfile(
4637                    os.path.join(old_model_dir, file), os.path.join(model_dir, file)
4638                )
4639            log_file = os.path.join(
4640                self.project_path, "results", "logs", f"{new_episode_name}.txt"
4641            )
4642            old_log_file = os.path.join(
4643                episodes_directory, "logs", f"{episode_name}.txt"
4644            )
4645            shutil.copyfile(old_log_file, log_file)
4646        print(import_string)
4647        print("\n")

Import episodes exported with Project.export_episodes

Parameters

episodes_directory : str the path to the exported episodes directory name_map : dict a name change dictionary for the episodes: keys are old names, values are new names

def export_episodes( self, episode_names: List, output_directory: str, name: str = None) -> None:
4649    def export_episodes(
4650        self, episode_names: List, output_directory: str, name: str = None
4651    ) -> None:
4652        """
4653        Save selected episodes as a file that can be imported into another project with `Project.import_episodes`
4654
4655        Parameters
4656        ----------
4657        episode_names : list
4658            a list of string episode names
4659        output_directory : str
4660            the path to the directory where the episodes will be saved
4661        name : str, optional
4662            the name of the episodes directory (by default `exported_episodes`)
4663        """
4664
4665        if name is None:
4666            name = "exported_episodes"
4667        if os.path.exists(
4668            os.path.join(output_directory, name + ".zip")
4669        ) or os.path.exists(os.path.join(output_directory, name)):
4670            i = 1
4671            while os.path.exists(
4672                os.path.join(output_directory, name + f"_{i}.zip")
4673            ) or os.path.exists(os.path.join(output_directory, name + f"_{i}")):
4674                i += 1
4675            name = name + f"_{i}"
4676        dest_dir = os.path.join(output_directory, name)
4677        os.mkdir(dest_dir)
4678        os.mkdir(os.path.join(dest_dir, "model"))
4679        os.mkdir(os.path.join(dest_dir, "logs"))
4680        runs = []
4681        for episode in episode_names:
4682            runs += self._episodes().get_runs(episode)
4683        for run in runs:
4684            shutil.copytree(
4685                os.path.join(self.project_path, "results", "model", run),
4686                os.path.join(dest_dir, "model", run),
4687            )
4688            shutil.copyfile(
4689                os.path.join(self.project_path, "results", "logs", f"{run}.txt"),
4690                os.path.join(dest_dir, "logs", f"{run}.txt"),
4691            )
4692        data = self._episodes().get_subset(runs)
4693        data.to_pickle(os.path.join(dest_dir, "episodes.pickle"))

Save selected episodes as a file that can be imported into another project with Project.import_episodes

Parameters

episode_names : list a list of string episode names output_directory : str the path to the directory where the episodes will be saved name : str, optional the name of the episodes directory (by default exported_episodes)

def get_results_table( self, episode_names: List, metrics: List = None, include_std: bool = False, classes: List = None)
4695    def get_results_table(
4696        self,
4697        episode_names: List,
4698        metrics: List = None,
4699        include_std: bool = False,
4700        classes: List = None,
4701    ):
4702        """
4703        Genererate a `pandas` dataframe with a summary of episode results
4704
4705        Parameters
4706        ----------
4707        episode_names : list
4708            a list of names of episodes to include
4709        metrics : list, optional
4710            a list of metric names to include
4711        include_std : bool, default False
4712            if `True`, for episodes with multiple runs the mean and standard deviation will be displayed;
4713            otherwise only mean
4714        classes : list, optional
4715            a list of names of classes to include (by default all are included)
4716
4717        Returns
4718        -------
4719        results : pd.DataFrame
4720            a table with the results
4721        """
4722
4723        run_names = []
4724        for episode in episode_names:
4725            run_names += self._episodes().get_runs(episode)
4726        episodes = self.list_episodes(run_names, print_results=False)
4727        metric_columns = [x for x in episodes.columns if x[0] == "results"]
4728        results_df = pd.DataFrame()
4729        if metrics is not None:
4730            metric_columns = [
4731                x for x in metric_columns if x[1].split("_")[0] in metrics
4732            ]
4733        for episode in episode_names:
4734            results = []
4735            metric_set = set()
4736            for run in self._episodes().get_runs(episode):
4737                beh_dict = self.get_behavior_dictionary(run)
4738                res_dict = defaultdict(lambda: {})
4739                for column in metric_columns:
4740                    if np.isnan(episodes.loc[run, column]):
4741                        continue
4742                    split = column[1].split("_")
4743                    if split[-1].isnumeric():
4744                        beh_ind = int(split[-1])
4745                        metric_name = "_".join(split[:-1])
4746                        beh = beh_dict[beh_ind]
4747                    else:
4748                        beh = "average"
4749                        metric_name = column[1]
4750                    res_dict[beh][metric_name] = episodes.loc[run, column]
4751                    metric_set.add(metric_name)
4752                if "average" not in res_dict:
4753                    res_dict["average"] = {}
4754                for metric in metric_set:
4755                    if metric not in res_dict["average"]:
4756                        arr = [
4757                            res_dict[beh][metric]
4758                            for beh in res_dict
4759                            if metric in res_dict[beh]
4760                        ]
4761                        res_dict["average"][metric] = np.mean(arr)
4762                results.append(res_dict)
4763            episode_results = {}
4764            for metric in metric_set:
4765                for beh in results[0].keys():
4766                    if classes is not None and beh not in classes:
4767                        continue
4768                    arr = []
4769                    for res_dict in results:
4770                        if metric in res_dict[beh]:
4771                            arr.append(res_dict[beh][metric])
4772                    if len(arr) > 0:
4773                        if include_std:
4774                            episode_results[
4775                                (beh, f"{episode} {metric} mean")
4776                            ] = np.mean(arr)
4777                            episode_results[(beh, f"{episode} {metric} std")] = np.std(
4778                                arr
4779                            )
4780                        else:
4781                            episode_results[(beh, f"{episode} {metric}")] = np.mean(arr)
4782            for key, value in episode_results.items():
4783                results_df.loc[key[0], key[1]] = value
4784        print(f"RESULTS:")
4785        print(results_df)
4786        print("\n")
4787        return results_df

Generate a pandas dataframe with a summary of episode results

Parameters

episode_names : list a list of names of episodes to include metrics : list, optional a list of metric names to include include_std : bool, default False if True, for episodes with multiple runs the mean and standard deviation will be displayed; otherwise only mean classes : list, optional a list of names of classes to include (by default all are included)

Returns

results : pd.DataFrame a table with the results

def episode_exists(self, episode_name: str) -> bool:
4789    def episode_exists(self, episode_name: str) -> bool:
4790        """
4791        Check if an episode already exists
4792
4793        Parameters
4794        ----------
4795        episode_name : str
4796            the episode name
4797
4798        Returns
4799        -------
4800        exists : bool
4801            `True` if the episode exists
4802        """
4803
4804        return self._episodes().check_name_validity(episode_name)

Check if an episode already exists

Parameters

episode_name : str the episode name

Returns

exists : bool True if the episode exists

def search_exists(self, search_name: str) -> bool:
4806    def search_exists(self, search_name: str) -> bool:
4807        """
4808        Check if a search already exists
4809
4810        Parameters
4811        ----------
4812        search_name : str
4813            the search name
4814
4815        Returns
4816        -------
4817        exists : bool
4818            `True` if the search exists
4819        """
4820
4821        return self._searches().check_name_validity(search_name)

Check if a search already exists

Parameters

search_name : str the search name

Returns

exists : bool True if the search exists

def prediction_exists(self, prediction_name: str) -> bool:
4823    def prediction_exists(self, prediction_name: str) -> bool:
4824        """
4825        Check if a prediction already exists
4826
4827        Parameters
4828        ----------
4829        prediction_name : str
4830            the prediction name
4831
4832        Returns
4833        -------
4834        exists : bool
4835            `True` if the prediction exists
4836        """
4837
4838        return self._predictions().check_name_validity(prediction_name)

Check if a prediction already exists

Parameters

prediction_name : str the prediction name

Returns

exists : bool True if the prediction exists

@staticmethod
def project_name_available(projects_path: str, project_name: str)
4840    @staticmethod
4841    def project_name_available(projects_path: str, project_name: str):
4842        if projects_path is None:
4843            projects_path = os.path.join(str(Path.home()), "DLC2Action")
4844        return not os.path.exists(os.path.join(projects_path, project_name))
def rename_episode(self, episode_name: str, new_episode_name: str)
    def rename_episode(self, episode_name: str, new_episode_name: str):
        """
        Rename an episode

        The model folder and the log file of the episode in the `results` folder are moved to
        the new name and the episode is then renamed in the episode records.

        Parameters
        ----------
        episode_name : str
            the current name of the episode
        new_episode_name : str
            the new name of the episode
        """
        # NOTE(review): no check is made here that the new name is free or that the source
        # files exist; presumably callers validate the names beforehand — confirm
        # move the model folder first, then the log file
        shutil.move(
            os.path.join(self.project_path, "results", "model", episode_name),
            os.path.join(self.project_path, "results", "model", new_episode_name),
        )
        shutil.move(
            os.path.join(self.project_path, "results", "logs", f"{episode_name}.txt"),
            os.path.join(
                self.project_path, "results", "logs", f"{new_episode_name}.txt"
            ),
        )
        # update the episode records only after the files were moved
        self._episodes().rename_episode(episode_name, new_episode_name)