dlc2action.project.project

Project interface

   1#
   2# Copyright 2020-present by A. Mathis Group and contributors. All rights reserved.
   3#
   4# This project and all its files are licensed under GNU AGPLv3 or later version. 
   5# A copy is included in dlc2action/LICENSE.AGPL.
   6#
   7"""
   8Project interface
   9"""
  10
  11import gc
  12import os
  13import pickle
  14import shutil
  15import time
  16import warnings
  17from abc import abstractmethod
  18from collections import defaultdict
  19from collections.abc import Iterable, Mapping
  20from copy import deepcopy
  21from email.policy import default
  22from itertools import product
  23from pathlib import Path
  24from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
  25
  26import cv2
  27import numpy as np
  28import optuna
  29import pandas as pd
  30import plotly
  31import torch
  32from matplotlib import cm
  33from matplotlib import pyplot as plt
  34from matplotlib import rc
  35from numpy import ndarray
  36from ruamel.yaml import YAML
  37from ruamel.yaml.comments import CommentedMap, CommentedSet
  38from tqdm import tqdm
  39
  40from dlc2action import __version__, options
  41from dlc2action.data.dataset import BehaviorDataset
  42from dlc2action.project.meta import (
  43    DecisionThresholds,
  44    Run,
  45    SavedRuns,
  46    SavedStores,
  47    Searches,
  48    Suggestions,
  49)
  50from dlc2action.task.task_dispatcher import TaskDispatcher
  51from dlc2action.utils import apply_threshold, binarize_data, load_pickle
  52
  53
  54class Project:
  55    """A class to create and maintain the project files + keep track of experiments."""
  56
  57    def __init__(
  58        self,
  59        name: str,
  60        data_type: str = None,
  61        annotation_type: str = "none",
  62        projects_path: str = None,
  63        data_path: Union[str, List] = None,
  64        annotation_path: Union[str, List] = None,
  65        copy: bool = False,
  66    ) -> None:
  67        """Initialize the class.
  68
  69        Parameters
  70        ----------
  71        name : str
  72            name of the project
  73        data_type : str, optional
  74            data type (run Project.data_types() to see available options; has to be provided if the project is being
  75            created)
  76        annotation_type : str, default 'none'
  77            annotation type (run Project.annotation_types() to see available options)
  78        projects_path : str, optional
  79            path to the projects folder (is filled with ~/DLC2Action by default)
  80        data_path : str, optional
  81            path to the folder containing input files for the project (has to be provided if the project is being
  82            created)
  83        annotation_path : str, optional
  84            path to the folder containing annotation files for the project
  85        copy : bool, default False
  86            if True, the files from annotation_path and data_path will be copied to the projects folder;
  87            otherwise they will be moved
  88
  89        """
  90        if projects_path is None:
  91            projects_path = os.path.join(str(Path.home()), "DLC2Action")
  92        if not os.path.exists(projects_path):
  93            os.mkdir(projects_path)
  94        self.project_path = os.path.join(projects_path, name)
  95        self.name = name
  96        self.data_type = data_type
  97        self.annotation_type = annotation_type
  98        self.data_path = data_path
  99        self.annotation_path = annotation_path
 100        if not os.path.exists(self.project_path):
 101            if data_type is None:
 102                raise ValueError(
 103                    "The data_type parameter is necessary when creating a new project!"
 104                )
 105            self._initialize_project(
 106                data_type, annotation_type, data_path, annotation_path, copy
 107            )
 108        else:
 109            self.annotation_type, self.data_type = self._read_types()
 110            if data_type != self.data_type and data_type is not None:
 111                raise ValueError(
 112                    f"The project has already been initialized with data_type={self.data_type}!"
 113                )
 114            if annotation_type != self.annotation_type and annotation_type != "none":
 115                raise ValueError(
 116                    f"The project has already been initialized with annotation_type={self.annotation_type}!"
 117                )
 118            self.annotation_path, data_path = self._read_paths()
 119            if self.data_path is None:
 120                self.data_path = data_path
 121            # if data_path != self.data_path and data_path is not None:
 122            #     raise ValueError(
 123            #         f"The project has already been initialized with data_path={self.data_path}!"
 124            #     )
 125            if annotation_path != self.annotation_path and annotation_path is not None:
 126                raise ValueError(
 127                    f"The project has already been initialized with annotation_path={self.annotation_path}!"
 128                )
 129        self._update_configs()
 130
 131    def _make_prediction(
 132        self,
 133        prediction_name: str,
 134        episode_names: List,
 135        load_epochs: Union[List[int], int] = None,
 136        parameters_update: Dict = None,
 137        data_path: str = None,
 138        file_paths: Set = None,
 139        mode: str = "all",
 140        augment_n: int = 0,
 141        evaluate: bool = False,
 142        task: TaskDispatcher = None,
 143        embedding: bool = False,
 144        annotation_type: str = "none",
 145    ) -> Tuple[TaskDispatcher, Dict, str, torch.Tensor]:
 146        """Generate a prediction.
 147        Parameters
 148        ----------
 149        prediction_name : str
 150            name of the prediction
 151            episode_names : List
 152            names of the episodes to use for the prediction
 153            load_epochs : Union[List[int],int], optional
 154            epochs to load for each episode; if a single integer is provided, it will be used for all episodes;
 155            if None, the last epochs will be used
 156            parameters_update : Dict, optional
 157            dictionary with parameters to update the task parameters
 158            data_path : str, optional
 159            path to the data folder; if None, the data_path from the project will be used
 160            file_paths : Set, optional
 161            set of file paths to use for the prediction; if None, the data_path will be used
 162            mode : str, default "all
 163            mode of the prediction; can be "train", "val", "test" or "all
 164            augment_n : int, default 0
 165            number of augmentations to apply to the data; if 0, no augmentations are applied
 166            evaluate : bool, default False
 167            if True, the prediction will be evaluated and the results will be saved to the episode meta file
 168            task : TaskDispatcher, optional
 169            task object to use for the prediction; if None, a new task object will be created
 170            embedding : bool, default False
 171            if True, the prediction will be returned as an embedding
 172            annotation_type : str, default "none
 173            type of the annotation to use for the prediction; if "none", the annotation will not be used
 174        Returns
 175        -------
 176        task : TaskDispatcher
 177            task object used for the prediction
 178        parameters : Dict
 179            parameters used for the prediction
 180        mode : str
 181            mode of the prediction
 182        prediction : torch.Tensor
 183            prediction tensor of shape (num_videos, num_behaviors, num_frames)
 184        inference_time : str
 185            time taken for the prediction in the format "HH:MM:SS"
 186        behavior_dict : Dict
 187            dictionary with behavior names and their indices
 188        """
 189
 190        names = []
 191        for episode_name in episode_names:
 192            names += self._episodes().get_runs(episode_name)
 193        if len(names) == 0:
 194            warnings.warn(f"None of the episodes {episode_names} exist!")
 195            names = [None]
 196        if load_epochs is None:
 197            load_epochs = [None for _ in names]
 198        elif isinstance(load_epochs, int):
 199            load_epochs = [load_epochs for _ in names]
 200        assert len(load_epochs) == len(
 201            names
 202        ), f"Length of load_epochs ({len(load_epochs)}) must match the number of episodes ({len(names)})!"
 203        prediction = None
 204        decision_thresholds = None
 205        time_total = 0
 206        behavior_dicts = [
 207            self.get_behavior_dictionary(episode_name) for episode_name in names
 208        ]
 209
 210        if not all(
 211            [
 212                set(d.values()) == set(behavior_dicts[0].values())
 213                for d in behavior_dicts[1:]
 214            ]
 215        ):
 216            raise ValueError(
 217                f"Episodes {episode_names} have different sets of behaviors!"
 218            )
 219        behaviors = list(behavior_dicts[0].values())
 220
 221        for episode_name, load_epoch, behavior_dict in zip(
 222            names, load_epochs, behavior_dicts
 223        ):
 224            print(f"episode {episode_name}")
 225            task, parameters, data_mode = self._make_task_prediction(
 226                prediction_name=prediction_name,
 227                load_episode=episode_name,
 228                parameters_update=parameters_update,
 229                load_epoch=load_epoch,
 230                data_path=data_path,
 231                mode=mode,
 232                file_paths=file_paths,
 233                task=task,
 234                decision_thresholds=decision_thresholds,
 235                annotation_type=annotation_type,
 236            )
 237            # data_mode = "train" if mode == "all" else mode
 238            time_start = time.time()
 239            new_pred = task.predict(
 240                data_mode,
 241                raw_output=True,
 242                apply_primary_function=True,
 243                augment_n=augment_n,
 244                embedding=embedding,
 245            )
 246            indices = [
 247                behaviors.index(behavior_dict[i]) for i in range(new_pred.shape[1])
 248            ]
 249            new_pred = new_pred[:, indices, :]
 250            time_end = time.time()
 251            time_total += time_end - time_start
 252            if evaluate:
 253                _, metrics = task.evaluate_prediction(
 254                    new_pred, data=data_mode, indices=indices
 255                )
 256                if mode == "val":
 257                    self._update_episode_metrics(episode_name, metrics)
 258            if prediction is None:
 259                prediction = new_pred
 260            else:
 261                prediction += new_pred
 262            print("\n")
 263        hours = int(time_total // 3600)
 264        time_total -= hours * 3600
 265        minutes = int(time_total // 60)
 266        time_total -= minutes * 60
 267        seconds = int(time_total)
 268        inference_time = f"{hours}:{minutes:02}:{seconds:02}"
 269        prediction /= len(names)
 270        return (
 271            task,
 272            parameters,
 273            data_mode,
 274            prediction,
 275            inference_time,
 276            behavior_dicts[0],
 277        )
 278
 279    def _make_task_prediction(
 280        self,
 281        prediction_name: str,
 282        load_episode: str = None,
 283        parameters_update: Dict = None,
 284        load_epoch: int = None,
 285        data_path: str = None,
 286        annotation_path: str = None,
 287        mode: str = "val",
 288        file_paths: Set = None,
 289        decision_thresholds: List = None,
 290        task: TaskDispatcher = None,
 291        annotation_type: str = "none",
 292    ) -> Tuple[TaskDispatcher, Dict, str]:
 293        """Make a `TaskDispatcher` object that will be used to generate a prediction."""
 294        if parameters_update is None:
 295            parameters_update = {}
 296        parameters_update_second = {}
 297        if mode == "all" or data_path is not None or file_paths is not None:
 298            parameters_update_second["training"] = {
 299                "val_frac": 0,
 300                "test_frac": 0,
 301                "partition_method": "random",
 302                "save_split": False,
 303                "split_path": None,
 304            }
 305            mode = "train"
 306        if decision_thresholds is not None:
 307            if (
 308                len(decision_thresholds)
 309                == self._episode(load_episode).get_num_classes()
 310            ):
 311                parameters_update_second["general"] = {
 312                    "threshold_value": decision_thresholds
 313                }
 314            else:
 315                raise ValueError(
 316                    f"The length of the decision thresholds {decision_thresholds} "
 317                    f"must be equal to the length of the behaviors dictionary "
 318                    f"{self._episode(load_episode).get_behaviors_dict()}"
 319                )
 320        data_param_update = {}
 321        if data_path is not None:
 322            data_param_update = {"data_path": data_path}
 323            if annotation_path is None:
 324                data_param_update["annotation_path"] = data_path
 325        if annotation_path is not None:
 326            data_param_update["annotation_path"] = annotation_path
 327        if file_paths is not None:
 328            data_param_update = {"data_path": None, "file_paths": file_paths}
 329        parameters_update = self._update(parameters_update, {"data": data_param_update})
 330        if data_path is not None or file_paths is not None:
 331            general_update = {
 332                "annotation_type": annotation_type,
 333                "only_load_annotated": False,
 334            }
 335        else:
 336            general_update = {}
 337        parameters_update = self._update(parameters_update, {"general": general_update})
 338        task, parameters = self._make_task(
 339            episode_name=prediction_name,
 340            load_episode=load_episode,
 341            parameters_update=parameters_update,
 342            parameters_update_second=parameters_update_second,
 343            load_epoch=load_epoch,
 344            purpose="prediction",
 345            task=task,
 346        )
 347        return task, parameters, mode
 348
 349    def _make_task_training(
 350        self,
 351        episode_name: str,
 352        load_episode: str = None,
 353        parameters_update: Dict = None,
 354        load_epoch: int = None,
 355        load_search: str = None,
 356        load_parameters: list = None,
 357        round_to_binary: list = None,
 358        load_strict: bool = True,
 359        continuing: bool = False,
 360        task: TaskDispatcher = None,
 361        mask_name: str = None,
 362        throwaway: bool = False,
 363    ) -> Tuple[TaskDispatcher, Dict, str]:
 364        """Make a `TaskDispatcher` object that will be used to generate a prediction."""
 365        if parameters_update is None:
 366            parameters_update = {}
 367        if continuing:
 368            purpose = "continuing"
 369        else:
 370            purpose = "training"
 371        if mask_name is not None:
 372            mask_name = os.path.join(self._mask_path(), f"{mask_name}.pickle")
 373        parameters_update_second = {"data": {"real_lens": mask_name}}
 374        if throwaway:
 375            parameters_update = self._update(
 376                parameters_update, {"training": {"normalize": False, "device": "cpu"}}
 377            )
 378        return self._make_task(
 379            episode_name,
 380            load_episode,
 381            parameters_update,
 382            parameters_update_second,
 383            load_epoch,
 384            load_search,
 385            load_parameters,
 386            round_to_binary,
 387            purpose,
 388            task,
 389            load_strict=load_strict,
 390        )
 391
 392    def _make_parameters(
 393        self,
 394        episode_name: str,
 395        load_episode: str = None,
 396        parameters_update: Dict = None,
 397        parameters_update_second: Dict = None,
 398        load_epoch: int = None,
 399        load_search: str = None,
 400        load_parameters: list = None,
 401        round_to_binary: list = None,
 402        purpose: str = "train",
 403        load_strict: bool = True,
 404    ):
 405        """Construct a parameters dictionary."""
 406        if parameters_update is None:
 407            parameters_update = {}
 408        pars_update = deepcopy(parameters_update)
 409        if parameters_update_second is None:
 410            parameters_update_second = {}
 411        if (
 412            purpose == "prediction"
 413            and "model" in pars_update.keys()
 414            and pars_update["general"]["model_name"] != "motionbert"
 415        ):
 416            raise ValueError("Cannot change model parameters after training!")
 417        if purpose in ["continuing", "prediction"] and load_episode is not None:
 418            read_parameters = self._read_parameters()
 419            parameters = self._episodes().load_parameters(load_episode)
 420            parameters["metrics"] = self._update(
 421                read_parameters["metrics"], parameters["metrics"]
 422            )
 423            parameters["ssl"] = self._update(
 424                read_parameters["ssl"], parameters.get("ssl", {})
 425            )
 426        else:
 427            parameters = self._read_parameters()
 428        if "model" in pars_update:
 429            model_params = pars_update.pop("model")
 430        else:
 431            model_params = None
 432        if "features" in pars_update:
 433            feat_params = pars_update.pop("features")
 434        else:
 435            feat_params = None
 436        if "augmentations" in pars_update:
 437            aug_params = pars_update.pop("augmentations")
 438        else:
 439            aug_params = None
 440        parameters = self._update(parameters, pars_update)
 441        if pars_update.get("general", {}).get("model_name") is not None:
 442            model_name = parameters["general"]["model_name"]
 443            parameters["model"] = self._open_yaml(
 444                os.path.join(self.project_path, "config", "model", f"{model_name}.yaml")
 445            )
 446        if pars_update.get("general", {}).get("feature_extraction") is not None:
 447            feat_name = parameters["general"]["feature_extraction"]
 448            parameters["features"] = self._open_yaml(
 449                os.path.join(
 450                    self.project_path, "config", "features", f"{feat_name}.yaml"
 451                )
 452            )
 453            aug_name = options.extractor_to_transformer[
 454                parameters["general"]["feature_extraction"]
 455            ]
 456            parameters["augmentations"] = self._open_yaml(
 457                os.path.join(
 458                    self.project_path, "config", "augmentations", f"{aug_name}.yaml"
 459                )
 460            )
 461        if model_params is not None:
 462            parameters["model"] = self._update(parameters["model"], model_params)
 463        if feat_params is not None:
 464            parameters["features"] = self._update(parameters["features"], feat_params)
 465        if aug_params is not None:
 466            parameters["augmentations"] = self._update(
 467                parameters["augmentations"], aug_params
 468            )
 469        if load_search is not None:
 470            parameters = self._update_with_search(
 471                parameters, load_search, load_parameters, round_to_binary
 472            )
 473        parameters = self._fill(
 474            parameters,
 475            episode_name,
 476            load_episode,
 477            load_epoch=load_epoch,
 478            load_strict=load_strict,
 479            only_load_model=(purpose != "continuing"),
 480            continuing=(purpose in ["prediction", "continuing"]),
 481            enforce_split_parameters=(purpose == "prediction"),
 482        )
 483        parameters = self._update(parameters, parameters_update_second)
 484        return parameters
 485
 486    def _make_task(
 487        self,
 488        episode_name: str,
 489        load_episode: str = None,
 490        parameters_update: Dict = None,
 491        parameters_update_second: Dict = None,
 492        load_epoch: int = None,
 493        load_search: str = None,
 494        load_parameters: list = None,
 495        round_to_binary: list = None,
 496        purpose: str = "train",
 497        task: TaskDispatcher = None,
 498        load_strict: bool = True,
 499    ) -> Tuple[TaskDispatcher, Union[CommentedMap, dict]]:
 500        """Make a `TaskDispatcher` object.
 501
 502        The task parameters are read from the config files and then updated with the
 503        parameters_update dictionary. The model can be either initialized from scratch or loaded from one of the
 504        previous experiments. All parameters and results are saved in the meta files and can be accessed with the
 505        list_episodes() function. The train/test/validation split is saved and loaded from a file whenever the
 506        same split parameters are used. The pre-computed datasets are also saved and loaded whenever the same
 507        data parameters are used.
 508
 509        Parameters
 510        ----------
 511        episode_name : str
 512            the name of the episode
 513        load_episode : str, optional
 514            the (previously run) episode name to load the model from
 515        parameters_update : dict, optional
 516            the dictionary used to update the parameters from the config
 517        parameters_update_second : dict, optional
 518            the dictionary used to update the parameters after the automatic fill-out
 519        load_epoch : int, optional
 520            the epoch to load (if load_episodes is not None); if not provided, the last epoch is used
 521        load_search : str, optional
 522            the hyperparameter search result to load
 523        load_parameters : list, optional
 524            a list of string names of the parameters to load from load_search (if not provided, all parameters
 525            are loaded)
 526        round_to_binary : list, optional
 527            a list of string names of the loaded parameters that should be rounded to the nearest power of two
 528        purpose : {"train", "continuing", "prediction"}
 529            the purpose of the task object (`"train"` for training from scratch, `"continuing"` for continuing
 530            the training of an interrupted episode, `"prediction"` for generating a prediction)
 531        task : TaskDispatcher, optional
 532            a pre-existing task; if provided, the method will update the task instead of creating a new one
 533            (this might save time, mainly on dataset loading)
 534
 535        Returns
 536        -------
 537        task : TaskDispatcher
 538            the `TaskDispatcher` instance
 539        parameters : dict
 540            the parameters dictionary that describes the task
 541
 542        """
 543        parameters = self._make_parameters(
 544            episode_name,
 545            load_episode,
 546            parameters_update,
 547            parameters_update_second,
 548            load_epoch,
 549            load_search,
 550            load_parameters,
 551            round_to_binary,
 552            purpose,
 553            load_strict=load_strict,
 554        )
 555        if task is None:
 556            task = TaskDispatcher(parameters)
 557        else:
 558            task.update_task(parameters)
 559        self._save_stores(parameters)
 560        return task, parameters
 561
 562    def get_decision_thresholds(
 563        self,
 564        episode_names: List,
 565        metric_name: str = "f1",
 566        parameters_update: Dict = None,
 567        load_epochs: List = None,
 568        remove_saved_features: bool = False,
 569    ) -> Tuple[List, List, TaskDispatcher]:
 570        """Compute optimal decision thresholds or load them if they have been computed before.
 571
 572        Parameters
 573        ----------
 574        episode_names : List
 575            a list of episode names
 576        metric_name : {"f1", "segmental_f1", "semisegmental_f1", "f_beta", "segmental_f_beta"}
 577            the metric to optimize
 578        parameters_update : dict, optional
 579            the parameter update dictionary
 580        load_epochs : list, optional
 581            a list of epochs to load (by default last are loaded)
 582        remove_saved_features : bool, default False
 583            if `True`, the dataset will be deleted after the computation
 584
 585        Returns
 586        -------
 587        thresholds : list
 588            a list of float decision threshold values
 589        classes : list
 590            the label names corresponding to the values
 591        task : TaskDispatcher | None
 592            the task used in computation
 593
 594        """
 595        parameters = self._make_parameters(
 596            "_",
 597            episode_names[0],
 598            parameters_update,
 599            {},
 600            load_epochs[0],
 601            purpose="prediction",
 602        )
 603        thresholds = self._thresholds().find_thresholds(
 604            episode_names,
 605            load_epochs,
 606            metric_name,
 607            metric_parameters=parameters["metrics"][metric_name],
 608        )
 609        task = None
 610        behaviors = list(self._episode(episode_names[0]).get_behaviors_dict().values())
 611        return thresholds, behaviors, task
 612
 613    def run_episode(
 614        self,
 615        episode_name: str,
 616        load_episode: str = None,
 617        parameters_update: Dict = None,
 618        task: TaskDispatcher = None,
 619        load_epoch: int = None,
 620        load_search: str = None,
 621        load_parameters: list = None,
 622        round_to_binary: list = None,
 623        load_strict: bool = True,
 624        n_seeds: int = 1,
 625        force: bool = False,
 626        suppress_name_check: bool = False,
 627        remove_saved_features: bool = False,
 628        mask_name: str = None,
 629        autostop_metric: str = None,
 630        autostop_interval: int = 50,
 631        autostop_threshold: float = 0.001,
 632        loading_bar: bool = False,
 633        trial: Tuple = None,
 634    ) -> TaskDispatcher:
 635        """Run an episode.
 636
 637        The task parameters are read from the config files and then updated with the
 638        parameters_update dictionary. The model can be either initialized from scratch or loaded from one of the
 639        previous experiments. All parameters and results are saved in the meta files and can be accessed with the
 640        list_episodes() function. The train/test/validation split is saved and loaded from a file whenever the
 641        same split parameters are used. The pre-computed datasets are also saved and loaded whenever the same
 642        data parameters are used.
 643
 644        You can use the autostop parameters to finish training when the parameters are not improving. It will be
 645        stopped if the average value of `autostop_metric` over the last `autostop_interval` epochs is smaller than
 646        the average over the previous `autostop_interval` epochs + `autostop_threshold`. For example, if the
 647        current epoch is 120 and `autostop_interval` is 50, the averages over epochs 70-120 and 20-70 will be compared.
 648
 649        Parameters
 650        ----------
 651        episode_name : str
 652            the episode name
 653        load_episode : str, optional
 654            the (previously run) episode name to load the model from; if the episode has multiple runs,
 655            the new episode will have the same number of runs, each starting with one of the pre-trained models
 656        parameters_update : dict, optional
 657            the dictionary used to update the parameters from the config files
 658        task : TaskDispatcher, optional
 659            a pre-existing `TaskDispatcher` object (if provided, the method will update it instead of creating
 660            a new instance)
 661        load_epoch : int, optional
 662            the epoch to load (if load_episodes is not None); if not provided, the last epoch is used
 663        load_search : str, optional
 664            the hyperparameter search result to load
 665        load_parameters : list, optional
 666            a list of string names of the parameters to load from load_search (if not provided, all parameters
 667            are loaded)
 668        round_to_binary : list, optional
 669            a list of string names of the loaded parameters that should be rounded to the nearest power of two
 670        load_strict : bool, default True
 671            if `False`, matching weights will be loaded from `load_episode` and differences in parameter name lists and
 672            weight shapes will be ignored; otherwise mismatches will prompt a `RuntimeError`
 673        n_seeds : int, default 1
 674            the number of runs to perform; if `n_seeds > 1`, the episodes will be named `episode_name#run_index`, e.g.
 675            `test_episode#0` and `test_episode#1`
 676        force : bool, default False
 677            if `True` and an episode with name `episode_name` already exists, it will be overwritten (use with caution!)
 678        suppress_name_check : bool, default False
 679            if `True`, episode names with a double colon are allowed (please don't use this option unless you understand
 680            why they are usually forbidden)
 681        remove_saved_features : bool, default False
 682            if `True`, the dataset will be deleted after training
 683        mask_name : str, optional
 684            the name of the real_lens to apply
 685        autostop_metric : str, optional
 686            the autostop metric (can be any one of the tracked metrics of `'loss'`)
 687        autostop_interval : int, default 50
 688            the number of epochs to average the autostop metric over
 689        autostop_threshold : float, default 0.001
 690            the autostop difference threshold
 691        loading_bar : bool, default False
 692            if `True`, a loading bar will be displayed
 693        trial : tuple, optional
 694            a tuple of (trial, metric) for hyperparameter search
 695
 696        Returns
 697        -------
 698        TaskDispatcher
 699            the `TaskDispatcher` object
 700
 701        """
 702
 703        import gc
 704
 705        gc.collect()
 706        if torch.cuda.is_available():
 707            torch.cuda.empty_cache()
 708
 709        if type(n_seeds) is not int or n_seeds < 1:
 710            raise ValueError(
 711                f"The n_seeds parameter has to be an integer larger than 0; got {n_seeds}"
 712            )
 713        if n_seeds > 1 and mask_name is not None:
 714            raise ValueError("Cannot apply a real_lens with n_seeds > 1")
 715        self._check_episode_validity(
 716            episode_name, allow_doublecolon=suppress_name_check, force=force
 717        )
 718        load_runs = self._episodes().get_runs(load_episode)
 719        if len(load_runs) > 1:
 720            task = self.run_episodes(
 721                episode_names=[
 722                    f'{episode_name}#{run.split("#")[-1]}' for run in load_runs
 723                ],
 724                load_episodes=load_runs,
 725                parameters_updates=[parameters_update for _ in load_runs],
 726                load_epochs=[load_epoch for _ in load_runs],
 727                load_searches=[load_search for _ in load_runs],
 728                load_parameters=[load_parameters for _ in load_runs],
 729                round_to_binary=[round_to_binary for _ in load_runs],
 730                load_strict=[load_strict for _ in load_runs],
 731                suppress_name_check=True,
 732                force=force,
 733                remove_saved_features=False,
 734            )
 735            if remove_saved_features:
 736                self._remove_stores(
 737                    {
 738                        "general": task.general_parameters,
 739                        "data": task.data_parameters,
 740                        "features": task.feature_parameters,
 741                    }
 742                )
 743            if n_seeds > 1:
 744                warnings.warn(
 745                    f"The n_seeds parameter is disregarded since load_episode={load_episode} has multiple runs"
 746                )
 747        elif n_seeds > 1:
 748
 749            self.run_episodes(
 750                episode_names=[f"{episode_name}#{i}" for i in range(n_seeds)],
 751                load_episodes=[load_episode for _ in range(n_seeds)],
 752                parameters_updates=[parameters_update for _ in range(n_seeds)],
 753                load_epochs=[load_epoch for _ in range(n_seeds)],
 754                load_searches=[load_search for _ in range(n_seeds)],
 755                load_parameters=[load_parameters for _ in range(n_seeds)],
 756                round_to_binary=[round_to_binary for _ in range(n_seeds)],
 757                load_strict=[load_strict for _ in range(n_seeds)],
 758                suppress_name_check=True,
 759                force=force,
 760                remove_saved_features=remove_saved_features,
 761            )
 762        else:
 763            print(f"TRAINING {episode_name}")
 764            try:
 765                task, parameters = self._make_task_training(
 766                    episode_name,
 767                    load_episode,
 768                    parameters_update,
 769                    load_epoch,
 770                    load_search,
 771                    load_parameters,
 772                    round_to_binary,
 773                    continuing=False,
 774                    task=task,
 775                    mask_name=mask_name,
 776                    load_strict=load_strict,
 777                )
 778                self._save_episode(
 779                    episode_name,
 780                    parameters,
 781                    task.behaviors_dict(),
 782                    norm_stats=task.get_normalization_stats(),
 783                )
 784                time_start = time.time()
 785                if trial is not None:
 786                    trial, metric = trial
 787                else:
 788                    trial, metric = None, None
 789                logs = task.train(
 790                    autostop_metric=autostop_metric,
 791                    autostop_interval=autostop_interval,
 792                    autostop_threshold=autostop_threshold,
 793                    loading_bar=loading_bar,
 794                    trial=trial,
 795                    optimized_metric=metric,
 796                )
 797                time_end = time.time()
 798                time_total = time_end - time_start
 799                hours = int(time_total // 3600)
 800                time_total -= hours * 3600
 801                minutes = int(time_total // 60)
 802                time_total -= minutes * 60
 803                seconds = int(time_total)
 804                training_time = f"{hours}:{minutes:02}:{seconds:02}"
 805                self._update_episode_results(episode_name, logs, training_time)
 806                if remove_saved_features:
 807                    self._remove_stores(parameters)
 808                print("\n")
 809                return task
 810
 811            except Exception as e:
 812                if isinstance(e, optuna.exceptions.TrialPruned):
 813                    raise e
 814                else:
 815                    # if str(e) != f"The {episode_name} episode name is already in use!":
 816                    #     self.remove_episode(episode_name)
 817                    raise RuntimeError(f"Episode {episode_name} could not run")
 818
 819    def run_episodes(
 820        self,
 821        episode_names: List,
 822        load_episodes: List = None,
 823        parameters_updates: List = None,
 824        load_epochs: List = None,
 825        load_searches: List = None,
 826        load_parameters: List = None,
 827        round_to_binary: List = None,
 828        load_strict: List = None,
 829        force: bool = False,
 830        suppress_name_check: bool = False,
 831        remove_saved_features: bool = False,
 832    ) -> TaskDispatcher:
 833        """Run multiple episodes in sequence (and re-use previously loaded information).
 834
 835        For each episode, the task parameters are read from the config files and then updated with the
 836        parameter_update dictionary. The model can be either initialized from scratch or loaded from one of the
 837        previous experiments. All parameters and results are saved in the meta files and can be accessed with the
 838        list_episodes() function. The train/test/validation split is saved and loaded from a file whenever the
 839        same split parameters are used. The pre-computed datasets are also saved and loaded whenever the same
 840        data parameters are used.
 841
 842        Parameters
 843        ----------
 844        episode_names : list
 845            a list of strings of episode names
 846        load_episodes : list, optional
 847            a list of strings of (previously run) episode names to load the model from; if the episode has multiple runs,
 848            the new episode will have the same number of runs, each starting with one of the pre-trained models
 849        parameters_updates : list, optional
 850            a list of dictionaries used to update the parameters from the config
 851        load_epochs : list, optional
 852            a list of integers used to specify the epoch to load (if load_episodes is not None)
 853        load_searches : list, optional
 854            a list of strings of hyperparameter search results to load
 855        load_parameters : list, optional
 856            a list of lists of string names of the parameters to load from the searches
 857        round_to_binary : list, optional
 858            a list of string names of the loaded parameters that should be rounded to the nearest power of two
 859        load_strict : list, optional
 860            a list of boolean values specifying weight loading policy: if `False`, matching weights will be loaded from
 861            the corresponding episode and differences in parameter name lists and
 862            weight shapes will be ignored; otherwise mismatches will prompt a `RuntimeError` (by default `True` for
 863            every episode)
 864        force : bool, default False
 865            if `True` and an episode name is already taken, it will be overwritten (use with caution!)
 866        suppress_name_check : bool, default False
 867            if `True`, episode names with a double colon are allowed (please don't use this option unless you understand
 868            why they are usually forbidden)
 869        remove_saved_features : bool, default False
 870            if `True`, the dataset will be deleted after training
 871
 872        Returns
 873        -------
 874        TaskDispatcher
 875            the task dispatcher object
 876
 877        """
 878        task = None
 879        if load_searches is None:
 880            load_searches = [None for _ in episode_names]
 881        if load_episodes is None:
 882            load_episodes = [None for _ in episode_names]
 883        if parameters_updates is None:
 884            parameters_updates = [None for _ in episode_names]
 885        if load_parameters is None:
 886            load_parameters = [None for _ in episode_names]
 887        if load_epochs is None:
 888            load_epochs = [None for _ in episode_names]
 889        if load_strict is None:
 890            load_strict = [True for _ in episode_names]
 891        for (
 892            parameters_update,
 893            episode_name,
 894            load_episode,
 895            load_epoch,
 896            load_search,
 897            load_parameters_list,
 898            load_strict_value,
 899        ) in zip(
 900            parameters_updates,
 901            episode_names,
 902            load_episodes,
 903            load_epochs,
 904            load_searches,
 905            load_parameters,
 906            load_strict,
 907        ):
 908            task = self.run_episode(
 909                episode_name,
 910                load_episode,
 911                parameters_update,
 912                task,
 913                load_epoch,
 914                load_search,
 915                load_parameters_list,
 916                round_to_binary,
 917                load_strict_value,
 918                suppress_name_check=suppress_name_check,
 919                force=force,
 920                remove_saved_features=remove_saved_features,
 921            )
 922        return task
 923
 924    def continue_episode(
 925        self,
 926        episode_name: str,
 927        num_epochs: int = None,
 928        task: TaskDispatcher = None,
 929        n_seeds: int = 1,
 930        remove_saved_features: bool = False,
 931        device: str = "cuda",
 932        num_cpus: int = None,
 933    ) -> TaskDispatcher:
 934        """Load an older episode and continue running from the latest checkpoint.
 935
 936        All parameters as well as the model and optimizer state dictionaries are loaded from the episode.
 937
 938        Parameters
 939        ----------
 940        episode_name : str
 941            the name of the episode to continue
 942        num_epochs : int, optional
 943            the new number of epochs
 944        task : TaskDispatcher, optional
 945            a pre-existing task; if provided, the method will update the task instead of creating a new one
 946            (this might save time, mainly on dataset loading)
 947        n_seeds : int, default 1
 948            the number of runs to perform; if `n_seeds > 1`, the episodes will be named `episode_name#run_index`, e.g.
 949            `test_episode#0` and `test_episode#1`
 950        remove_saved_features : bool, default False
 951            if `True`, pre-computed features will be deleted after the run
 952        device : str, default "cuda"
 953            the torch device to use
 954        num_cpus : int, optional
 955            the number of CPUs to use for data loading; if `None`, the number of available CPUs will be used
 956
 957        Returns
 958        -------
 959        TaskDispatcher
 960            the task dispatcher
 961
 962        """
 963        runs = self._episodes().get_runs(episode_name)
 964        for run in runs:
 965            print(f"TRAINING {run}")
 966            if num_epochs is None and not self._episode(run).unfinished():
 967                continue
 968            parameters_update = {
 969                "training": {
 970                    "num_epochs": num_epochs,
 971                    "device": device,
 972                },
 973                "general": {"num_cpus": num_cpus},
 974            }
 975            task, parameters = self._make_task_training(
 976                run,
 977                load_episode=run,
 978                parameters_update=parameters_update,
 979                continuing=True,
 980                task=task,
 981            )
 982            time_start = time.time()
 983            logs = task.train()
 984            time_end = time.time()
 985            old_time = self._training_time(run)
 986            if not np.isnan(old_time):
 987                time_end += old_time
 988                time_total = time_end - time_start
 989                hours = int(time_total // 3600)
 990                time_total -= hours * 3600
 991                minutes = int(time_total // 60)
 992                time_total -= minutes * 60
 993                seconds = int(time_total)
 994                training_time = f"{hours}:{minutes:02}:{seconds:02}"
 995            else:
 996                training_time = np.nan
 997            self._save_episode(
 998                run,
 999                parameters,
1000                task.behaviors_dict(),
1001                suppress_validation=True,
1002                training_time=training_time,
1003                norm_stats=task.get_normalization_stats(),
1004            )
1005            self._update_episode_results(run, logs)
1006            print("\n")
1007        if len(runs) < n_seeds:
1008            for i in range(len(runs), n_seeds):
1009                self.run_episode(
1010                    f"{episode_name}#{i}",
1011                    parameters_update=self._episodes().load_parameters(runs[0]),
1012                    task=task,
1013                    suppress_name_check=True,
1014                )
1015        if remove_saved_features:
1016            self._remove_stores(parameters)
1017        return task
1018
1019    def run_default_hyperparameter_search(
1020        self,
1021        search_name: str,
1022        model_name: str,
1023        metric: str = "f1",
1024        best_n: int = 3,
1025        direction: str = "maximize",
1026        load_episode: str = None,
1027        load_epoch: int = None,
1028        load_strict: bool = True,
1029        prune: bool = True,
1030        force: bool = False,
1031        remove_saved_features: bool = False,
1032        overlap: float = 0,
1033        num_epochs: int = 50,
1034        test_frac: float = None,
1035        n_trials=150,
1036        batch_size=32,
1037    ):
1038        """Run an optuna hyperparameter search with default parameters for a model.
1039
1040        For the vast majority of cases, optimizing the default parameters should be enough.
1041        Check out `dlc2action.options.model_hyperparameters` for the lists of parameters.
1042        There are also options to set overlap, test fraction and number of epochs parameters for the search without
1043        modifying the project config files. However, if you want something more complex, look into
1044        `Project.run_hyperparameter_search`.
1045
1046        The task parameters are read from the config files and updated with the parameters_update dictionary.
1047        The model can be either initialized from scratch or loaded from a previously run episode.
1048        For each trial, the objective metric is averaged over a few best epochs.
1049
1050        Parameters
1051        ----------
1052        search_name : str
1053            the name of the search to store it in the meta files and load in run_episode
1054        model_name : str
1055            the name
1056        metric : str
1057            the metric to maximize/minimize (see direction); if the metric has an `"average"` parameter and it is set to
1058            `"none"` in the config files, it will be reset to `"macro"` for the search
1059        best_n : int, default 1
1060            the number of epochs to average the metric; if 0, the last value is taken
1061        direction : {'maximize', 'minimize'}
1062            optimization direction
1063        load_episode : str, optional
1064            the name of the episode to load the model from
1065        load_epoch : int, optional
1066            the epoch to load the model from (if not provided, the last checkpoint is used)
1067        load_strict : bool, default True
1068            if `True`, the model will be loaded only if the parameters match exactly
1069        prune : bool, default False
1070            if `True`, experiments where the optimized metric is improving too slowly will be terminated
1071            (with optuna HyperBand pruner)
1072        force : bool, default False
1073            if `True`, existing searches with the same name will be overwritten
1074        remove_saved_features : bool, default False
1075            if `True`, pre-computed features will be deleted after each run (if the data parameters change)
1076        overlap : float, default 0
1077            the overlap to use for the search
1078        num_epochs : int, default 50
1079            the number of epochs to use for the search
1080        test_frac : float, optional
1081            the test fraction to use for the search
1082        n_trials : int, default 150
1083            the number of trials to run
1084        batch_size : int, default 32
1085            the batch size to use for the search
1086
1087        Returns
1088        -------
1089        best_parameters : dict
1090            a dictionary of best parameters
1091
1092        """
1093        if model_name not in options.model_hyperparameters:
1094            raise ValueError(
1095                f"There is no default search space for {model_name}! Please choose from {options.model_hyperparameters.keys()} or try project.run_hyperparameter_search()"
1096            )
1097        pars = {
1098            "general": {"overlap": overlap, "model_name": model_name},
1099            "training": {"num_epochs": num_epochs, "batch_size": batch_size},
1100        }
1101        if test_frac is not None:
1102            pars["training"]["test_frac"] = test_frac
1103        if not metric.split("_")[-1].isnumeric():
1104            project_pars = self._read_parameters()
1105            if project_pars["metrics"][metric].get("average") == "none":
1106                pars["metrics"] = {metric: {"average": "macro"}}
1107        return self.run_hyperparameter_search(
1108            search_name=search_name,
1109            search_space=options.model_hyperparameters[model_name],
1110            metric=metric,
1111            n_trials=n_trials,
1112            best_n=best_n,
1113            parameters_update=pars,
1114            direction=direction,
1115            load_episode=load_episode,
1116            load_epoch=load_epoch,
1117            load_strict=load_strict,
1118            prune=prune,
1119            force=force,
1120            remove_saved_features=remove_saved_features,
1121        )
1122
1123    def run_hyperparameter_search(
1124        self,
1125        search_name: str,
1126        search_space: Dict,
1127        metric: str = "f1",
1128        n_trials: int = 20,
1129        best_n: int = 1,
1130        parameters_update: Dict = None,
1131        direction: str = "maximize",
1132        load_episode: str = None,
1133        load_epoch: int = None,
1134        load_strict: bool = True,
1135        prune: bool = False,
1136        force: bool = False,
1137        remove_saved_features: bool = False,
1138        make_plots: bool = True,
1139    ) -> Dict:
1140        """Run an optuna hyperparameter search.
1141
1142        For a simpler function that fits most use cases, check out `Project.run_default_hyperparameter_search()`.
1143
1144        To use a default search space with this method, import `dlc2action.options.model_hyperparameters`. It is
1145        a dictionary where keys are model names and values are default search spaces.
1146
1147        The task parameters are read from the config files and updated with the parameters_update dictionary.
1148        The model can be either initialized from scratch or loaded from a previously run episode.
1149        For each trial, the objective metric is averaged over a few best epochs.
1150
1151        Parameters
1152        ----------
1153        search_name : str
1154            the name of the search to store it in the meta files and load in run_episode
1155        search_space : dict
1156            a dictionary representing the search space; of this general structure:
1157            {'group/param_name': ('float/int/float_log/int_log', start, end),
1158            'group/param_name': ('categorical', [choices])}, e.g.
1159            {'data/overlap': ('int', 5, 100), 'training/lr': ('float_log', 1e-4, 1e-2),
1160            'data/feature_extraction': ('categorical', ['kinematic', 'bones'])};
1161        metric : str, default f1
1162            the metric to maximize/minimize (see direction)
1163        n_trials : int, default 20
1164            the number of optimization trials to run
1165        best_n : int, default 1
1166            the number of epochs to average the metric; if 0, the last value is taken
1167        parameters_update : dict, optional
1168            the parameters update dictionary
1169        direction : {'maximize', 'minimize'}
1170            optimization direction
1171        load_episode : str, optional
1172            the name of the episode to load the model from
1173        load_epoch : int, optional
1174            the epoch to load the model from (if not provided, the last checkpoint is used)
1175        load_strict : bool, default True
1176            if `True`, the model will be loaded only if the parameters match exactly
1177        prune : bool, default False
1178            if `True`, experiments where the optimized metric is improving too slowly will be terminated
1179            (with optuna HyperBand pruner)
1180        force : bool, default False
1181            if `True`, existing searches with the same name will be overwritten
1182        remove_saved_features : bool, default False
1183            if `True`, pre-computed features will be deleted after each run (if the data parameters change)
1184
1185        Returns
1186        -------
1187        dict
1188            a dictionary of best parameters
1189
1190        """
1191        self._check_search_validity(search_name, force=force)
1192        print(f"SEARCH {search_name}")
1193        self.remove_episode(f"_{search_name}")
1194        if parameters_update is None:
1195            parameters_update = {}
1196        parameters_update = self._update(
1197            parameters_update, {"general": {"metric_functions": {metric}}}
1198        )
1199        parameters = self._make_parameters(
1200            f"_{search_name}",
1201            load_episode,
1202            parameters_update,
1203            parameters_update_second={"training": {"model_save_path": None}},
1204            load_epoch=load_epoch,
1205            load_strict=load_strict,
1206        )
1207        task = None
1208
1209        if prune:
1210            pruner = optuna.pruners.HyperbandPruner()
1211        else:
1212            pruner = optuna.pruners.NopPruner()
1213        study = optuna.create_study(direction=direction, pruner=pruner)
1214        runner = _Runner(
1215            search_space=search_space,
1216            load_episode=load_episode,
1217            load_epoch=load_epoch,
1218            metric=metric,
1219            average=best_n,
1220            task=task,
1221            remove_saved_features=remove_saved_features,
1222            project=self,
1223            search_name=search_name,
1224        )
1225        study.optimize(lambda trial: runner.run(trial, parameters), n_trials=n_trials)
1226        if make_plots:
1227            search_path = self._search_path(search_name)
1228            os.mkdir(search_path)
1229            fig = optuna.visualization.plot_contour(study)
1230            plotly.offline.plot(
1231                fig, filename=os.path.join(search_path, f"{search_name}_contour.html")
1232            )
1233            fig = optuna.visualization.plot_param_importances(study)
1234            plotly.offline.plot(
1235                fig,
1236                filename=os.path.join(search_path, f"{search_name}_importances.html"),
1237            )
1238        best_params = study.best_params
1239        best_value = study.best_value
1240        if best_value == 0 or best_value == float("inf"):
1241            raise ValueError(
1242                f"Best metric value is {best_value}, check your partition method and make sure that all behaviors are present in the validation set!"
1243            )
1244        self._save_search(
1245            search_name,
1246            parameters,
1247            n_trials,
1248            best_params,
1249            best_value,
1250            metric,
1251            search_space,
1252        )
1253        self.remove_episode(f"_{search_name}")
1254        runner.clean()
1255        print(f"best parameters: {best_params}")
1256        print("\n")
1257        return best_params
1258
1259    def run_prediction(
1260        self,
1261        prediction_name: str,
1262        episode_names: List,
1263        load_epochs: List = None,
1264        parameters_update: Dict = None,
1265        augment_n: int = 10,
1266        data_path: str = None,
1267        mode: str = "all",
1268        file_paths: Set = None,
1269        remove_saved_features: bool = False,
1270        frame_number_map_file: str = None,
1271        force: bool = False,
1272        embedding: bool = False,
1273    ) -> None:
1274        """Load models from previously run episodes to generate a prediction.
1275
1276        The probabilities predicted by the models are averaged.
1277        Unless `submission` is `True`, the prediction results are saved as a pickled dictionary in the project_name/results/predictions folder
1278        under the {episode_name}_{load_epoch}.pickle name. The file is a nested dictionary where the first-level
1279        keys are the video ids, the second-level keys are the clip ids (like individual names) and the values
1280        are the prediction arrays.
1281
1282        Parameters
1283        ----------
1284        prediction_name : str
1285            the name of the prediction
1286        episode_names : list
1287            a list of string episode names to load the models from
1288        load_epochs : list or int, optional
1289            a list of integer epoch indices to load the model from; if None, the last ones are used, if int the same epoch is used for all episodes
1290        parameters_update : dict, optional
1291            a dictionary of parameter updates
1292        augment_n : int, default 10
1293            the number of augmentations to average over
1294        data_path : str, optional
1295            the data path to run the prediction for
1296        mode : {'all', 'test', 'val', 'train'}
1297            the subset of the data to make the prediction for (forced to 'all' if data_path is not None)
1298        file_paths : set, optional
1299            a set of string file paths (data with all prefixes + feature files, in any order) to run the prediction
1300            for
1301        remove_saved_features : bool, default False
1302            if `True`, pre-computed features will be deleted
1303        submission : bool, default False
1304            if `True`, a MABe-22 style submission file is generated
1305        frame_number_map_file : str, optional
1306            path to the frame number map file
1307        force : bool, default False
1308            if `True`, existing prediction with this name will be overwritten
1309        embedding : bool, default False
1310            if `True`, the prediction is made for the embedding task
1311
1312        """
1313        self._check_prediction_validity(prediction_name, force=force)
1314        print(f"PREDICTION {prediction_name}")
1315        task, parameters, mode, prediction, inference_time, behavior_dict = (
1316            self._make_prediction(
1317                prediction_name,
1318                episode_names,
1319                load_epochs,
1320                parameters_update,
1321                data_path,
1322                file_paths,
1323                mode,
1324                augment_n,
1325                evaluate=False,
1326                embedding=embedding,
1327            )
1328        )
1329        predicted = task.dataset(mode).generate_full_length_prediction(prediction)
1330
1331        if remove_saved_features:
1332            self._remove_stores(parameters)
1333
1334        self._save_prediction(
1335            prediction_name,
1336            predicted,
1337            parameters,
1338            task,
1339            mode,
1340            embedding,
1341            inference_time,
1342            behavior_dict,
1343        )
1344        print("\n")
1345
1346    def evaluate_prediction(
1347        self,
1348        prediction_name: str,
1349        parameters_update: Dict = None,
1350        data_path: str = None,
1351        annotation_path: str = None,
1352        file_paths: Set = None,
1353        mode: str = None,
1354        remove_saved_features: bool = False,
1355        annotation_type: str = "none",
1356        num_classes: int = None,  # Set when using data_path
1357    ) -> Tuple[float, dict]:
1358        """Make predictions and evaluate them
1359        inputs:
1360            prediction_name (str): the name of the prediction
1361            parameters_update (dict): a dictionary of parameter updates
1362            data_path (str): the data path to run the prediction for
1363            annotation_path (str): the annotation path to run the prediction for
1364            file_paths (set): a set of string file paths (data with all prefixes + feature files, in any order) to run the prediction for
1365            mode (str): the subset of the data to make the prediction for (forced to 'all' if data_path is not None)
1366            remove_saved_features (bool): if `True`, pre-computed features will be deleted
1367            annotation_type (str): the type of annotation to use for evaluation
1368            num_classes (int): the number of classes in the dataset, must be set with data_path
1369        outputs:
1370            results (dict): a dictionary of average values of metric functions
1371        """
1372
1373        prediction_path = os.path.join(
1374            self.project_path, "results", "predictions", f"{prediction_name}"
1375        )
1376        prediction_dict = {}
1377        for prediction_file_path in [
1378            os.path.join(prediction_path, i) for i in os.listdir(prediction_path)
1379        ]:
1380            with open(os.path.join(prediction_file_path), "rb") as f:
1381                prediction = pickle.load(f)
1382            video_id = os.path.basename(prediction_file_path).split(
1383                "_" + prediction_name
1384            )[0]
1385            prediction_dict[video_id] = prediction
1386        if parameters_update is None:
1387            parameters_update = {}
1388        parameters_update = self._update(
1389            self._predictions().load_parameters(prediction_name), parameters_update
1390        )
1391        parameters_update.pop("model")
1392        if not data_path is None:
1393            assert (
1394                not num_classes is None
1395            ), "num_classes must be provided if data_path is provided"
1396            parameters_update["general"]["num_classes"] = num_classes + int(
1397                parameters_update["general"]["exclusive"]
1398            )
1399        task, parameters, mode = self._make_task_prediction(
1400            "_",
1401            load_episode=None,
1402            parameters_update=parameters_update,
1403            data_path=data_path,
1404            annotation_path=annotation_path,
1405            file_paths=file_paths,
1406            mode=mode,
1407            annotation_type=annotation_type,
1408        )
1409        results = task.evaluate_prediction(prediction_dict, data=mode)
1410        if remove_saved_features:
1411            self._remove_stores(parameters)
1412        results = Project._reformat_results(
1413            results[1],
1414            task.behaviors_dict(),
1415            exclusive=task.general_parameters["exclusive"],
1416        )
1417        return results
1418
1419    def evaluate(
1420        self,
1421        episode_names: List,
1422        load_epochs: List = None,
1423        augment_n: int = 0,
1424        data_path: str = None,
1425        file_paths: Set = None,
1426        mode: str = None,
1427        parameters_update: Dict = None,
1428        multiple_episode_policy: str = "average",
1429        remove_saved_features: bool = False,
1430        skip_updating_meta: bool = True,
1431        annotation_type: str = "none",
1432    ) -> Dict:
1433        """Load one or several models from previously run episodes to make an evaluation.
1434
1435        By default it will run on the test (or validation, if there is no test) subset of the project dataset.
1436
1437        Parameters
1438        ----------
1439        episode_names : list
1440            a list of string episode names to load the models from
1441        load_epochs : list, optional
1442            a list of integer epoch indices to load the model from; if None, the last ones are used
1443        augment_n : int, default 0
1444            the number of augmentations to average over
1445        data_path : str, optional
1446            the data path to run the prediction for
1447        file_paths : set, optional
1448            a set of files to run the prediction for
1449        mode : {'test', 'val', 'train', 'all'}
1450            the subset of the data to make the prediction for (forced to 'all' if data_path is not None;
1451            by default 'test' if test subset is not empty and 'val' otherwise)
1452        parameters_update : dict, optional
1453            a dictionary with parameter updates (cannot change model parameters)
1454        multiple_episode_policy : {'average', 'statistics'}
1455            the policy to use when multiple episodes are provided
1456        remove_saved_features : bool, default False
1457            if `True`, the dataset will be deleted
1458        skip_updating_meta : bool, default True
1459            if `True`, the meta file will not be updated with the computed metrics
1460
1461        Returns
1462        -------
1463        metric : dict
1464            a dictionary of average values of metric functions
1465
1466        """
1467        names = []
1468        for episode_name in episode_names:
1469            names += self._episodes().get_runs(episode_name)
1470        if len(set(episode_names)) == 1:
1471            print(f"EVALUATION {episode_names[0]}")
1472        else:
1473            print(f"EVALUATION {episode_names}")
1474        if len(names) > 1:
1475            evaluate = True
1476        else:
1477            evaluate = False
1478        if multiple_episode_policy == "average":
1479            task, parameters, mode, prediction, inference_time, behavior_dict = (
1480                self._make_prediction(
1481                    "_",
1482                    episode_names,
1483                    load_epochs,
1484                    parameters_update,
1485                    mode=mode,
1486                    data_path=data_path,
1487                    file_paths=file_paths,
1488                    augment_n=augment_n,
1489                    evaluate=evaluate,
1490                    annotation_type=annotation_type,
1491                )
1492            )
1493            print("EVALUATE PREDICTION:")
1494            indices = [
1495                list(behavior_dict.keys()).index(i) for i in range(len(behavior_dict))
1496            ]
1497            _, results = task.evaluate_prediction(
1498                prediction, data=mode, indices=indices
1499            )
1500            if len(names) == 1 and mode == "val" and not skip_updating_meta:
1501                self._update_episode_metrics(names[0], results)
1502            results = Project._reformat_results(
1503                results,
1504                behavior_dict,
1505                exclusive=task.general_parameters["exclusive"],
1506            )
1507
1508        elif multiple_episode_policy == "statistics":
1509            values = defaultdict(lambda: [])
1510            task = None
1511            for name in names:
1512                (
1513                    task,
1514                    parameters,
1515                    mode,
1516                    prediction,
1517                    inference_time,
1518                    behavior_dict,
1519                ) = self._make_prediction(
1520                    "_",
1521                    [name],
1522                    load_epochs,
1523                    parameters_update,
1524                    mode=mode,
1525                    data_path=data_path,
1526                    file_paths=file_paths,
1527                    augment_n=augment_n,
1528                    evaluate=evaluate,
1529                    task=task,
1530                )
1531                _, metrics = task.evaluate_prediction(
1532                    prediction, data=mode, indices=list(behavior_dict.keys())
1533                )
1534                for name, value in metrics.items():
1535                    values[name].append(value)
1536                if mode == "val" and not skip_updating_meta:
1537                    self._update_episode_metrics(name, metrics)
1538            results = defaultdict(lambda: {})
1539            mean_string = ""
1540            std_string = ""
1541            for key, value_list in values.items():
1542                results[key]["mean"] = np.mean(value_list)
1543                results[key]["std"] = np.std(value_list)
1544                results[key]["all"] = value_list
1545                mean_string += f"{key} {np.mean(value_list):.3f}, "
1546                std_string += f"{key} {np.std(value_list):.3f}, "
1547            print("MEAN:")
1548            print(mean_string)
1549            print("STD:")
1550            print(std_string)
1551        else:
1552            raise ValueError(
1553                f"The {multiple_episode_policy} multiple episode policy is not recognized; please choose "
1554                f"from ['average', 'statistics']"
1555            )
1556        if len(names) > 0 and remove_saved_features:
1557            self._remove_stores(parameters)
1558        print(f"Inference time: {inference_time}")
1559        print("\n")
1560        return results
1561
1562    def run_suggestion(
1563        self,
1564        suggestions_name: str,
1565        error_episode: str = None,
1566        error_load_epoch: int = None,
1567        error_class: str = None,
1568        suggestions_prediction: str = None,
1569        suggestion_episodes: List = [None],
1570        suggestion_load_epoch: int = None,
1571        suggestion_classes: List = None,
1572        error_threshold: float = 0.5,
1573        error_threshold_diff: float = 0.1,
1574        error_hysteresis: bool = False,
1575        suggestion_threshold: Union[float, List] = 0.5,
1576        suggestion_threshold_diff: Union[float, List] = 0.1,
1577        suggestion_hysteresis: Union[bool, List] = True,
1578        min_frames_suggestion: int = 10,
1579        min_frames_al: int = 30,
1580        visibility_min_score: float = 0,
1581        visibility_min_frac: float = 0.7,
1582        augment_n: int = 0,
1583        exclude_classes: List = None,
1584        exclude_threshold: Union[float, List] = 0.6,
1585        exclude_threshold_diff: Union[float, List] = 0.1,
1586        exclude_hysteresis: Union[bool, List] = False,
1587        include_classes: List = None,
1588        include_threshold: Union[float, List] = 0.4,
1589        include_threshold_diff: Union[float, List] = 0.1,
1590        include_hysteresis: Union[bool, List] = False,
1591        data_path: str = None,
1592        file_paths: Set = None,
1593        parameters_update: Dict = None,
1594        mode: str = "all",
1595        force: bool = False,
1596        remove_saved_features: bool = False,
1597        cut_annotated: bool = False,
1598        background_threshold: float = None,
1599    ) -> None:
1600        """Create active learning and suggestion files.
1601
1602        Generate predictions with the error and suggestion model and use them to create
1603        suggestion files for the labeling interface. Those files will render as suggested labels
1604        at intervals with high pose estimation quality. Quality here is defined by probability of error
1605        (predicted by the error model) and visibility parameters.
1606
1607        If `error_episode` or `exclude_classes` is not `None`,
1608        an active learning file will be created as well (with frames with high predicted probability of classes
1609        from `exclude_classes` and/or errors excluded from the active learning intervals).
1610
1611        In all three steps (predicting errors, suggesting labels and excluding them from active learning intervals)
1612        you can apply one of three methods.
1613
1614        - **Simple threshold**
1615
1616            Set the `hysteresis` parameter (e.g. `error_hysteresis`) to `False` and the `threshold`
1617            parameter to $\alpha$.
1618            In this case if the probability of a label is predicted to be higher than $\alpha$ the frame will
1619            be considered labeled.
1620
1621        - **Hysteresis threshold**
1622
1623            Set the `hysteresis` parameter (e.g. `error_hysteresis`) to `True`, the `threshold`
1624            parameter to $\alpha$ and the `threshold_diff` parameter to $\beta$.
1625            Now intervals will be marked with a label if the probability of that label for all frames is higher
1626            than $\alpha - \beta$ and at least for one frame in that interval it is higher than $\alpha$.
1627
1628        - **Max hysteresis threshold**
1629
1630            Set the `hysteresis` parameter (e.g. `error_hysteresis`) to `True`, the `threshold`
1631            parameter to $\alpha$ and the `threshold_diff` parameter to `None`.
1632            With this combination intervals are marked with a label if that label is more likely than any other
1633            for all frames in this interval and at for at least one of those frames its probability is higher than
1634            $\alpha$.
1635
1636        Parameters
1637        ----------
1638        suggestions_name : str
1639            the name of the suggestions
1640        error_episode : str, optional
1641            the name of the episode where the error model should be loaded from
1642        error_load_epoch : int, optional
1643            the epoch the error model should be loaded from
1644        error_class : str, optional
1645            the name of the error class (in `error_episode`)
1646        suggestions_prediction : str, optional
1647            the name of the predictions that should be used for the suggestion model
1648        suggestion_episodes : list, optional
1649            the names of the episodes where the suggestion models should be loaded from
1650        suggestion_load_epoch : int, optional
1651            the epoch the suggestion model should be loaded from
1652        suggestion_classes : list, optional
1653            a list of string names of the classes that should be suggested (in `suggestion_episode`)
1654        error_threshold : float, default 0.5
1655            the hard threshold for error prediction
1656        error_threshold_diff : float, default 0.1
1657            the difference between soft and hard thresholds for error prediction (in case hysteresis is used)
1658        error_hysteresis : bool, default False
1659            if True, hysteresis is used for error prediction
1660        suggestion_threshold : float | list, default 0.5
1661            the hard threshold for class prediction (use a list to set different rules for different classes)
1662        suggestion_threshold_diff : float | list, default 0.1
1663            the difference between soft and hard thresholds for class prediction (in case hysteresis is used;
1664            use a list to set different rules for different classes)
1665        suggestion_hysteresis : bool | list, default True
1666            if True, hysteresis is used for class prediction (use a list to set different rules for different classes)
1667        min_frames_suggestion : int, default 10
1668            only actions longer than this number of frames will be suggested
1669        min_frames_al : int, default 30
1670            only active learning intervals longer than this number of frames will be suggested
1671        visibility_min_score : float, default 0
1672            the minimum visibility score for visibility filtering
1673        visibility_min_frac : float, default 0.7
1674            the minimum fraction of visible frames for visibility filtering
1675        augment_n : int, default 10
1676            the number of augmentations to average the predictions over
1677        exclude_classes : list, optional
1678            a list of string names of classes that should be excluded from the active learning intervals
1679        exclude_threshold : float | list, default 0.6
1680            the hard threshold for excluded class prediction (use a list to set different rules for different classes)
1681        exclude_threshold_diff : float | list, default 0.1
1682            the difference between soft and hard thresholds for excluded class prediction (in case hysteresis is used)
1683        exclude_hysteresis : bool | list, default False
1684            if True, hysteresis is used for excluded class prediction (use a list to set different rules for different classes)
1685        include_classes : list, optional
1686            a list of string names of classes that should be included into the active learning intervals
1687        include_threshold : float | list, default 0.6
1688            the hard threshold for included class prediction (use a list to set different rules for different classes)
1689        include_threshold_diff : float | list, default 0.1
1690            the difference between soft and hard thresholds for included class prediction (in case hysteresis is used)
1691        include_hysteresis : bool | list, default False
1692            if True, hysteresis is used for included class prediction (use a list to set different rules for different classes)
1693        data_path : str, optional
1694            the data path to run the prediction for
1695        file_paths : set, optional
1696            a set of string file paths (data with all prefixes + feature files, in any order) to run the prediction
1697            for
1698        parameters_update : dict, optional
1699            the parameters update dictionary
1700        mode : {'all', 'test', 'val', 'train'}
1701            the subset of the data to make the prediction for (forced to 'all' if data_path is not None)
1702        force : bool, default False
1703            if `True` and an episode with name `episode_name` already exists, it will be overwritten (use with caution!)
1704        remove_saved_features : bool, default False
1705            if `True`, the dataset will be deleted.
1706        cut_annotated : bool, default False
1707            if `True`, annotated frames will be cut from the suggestions
1708        background_threshold : float, default 0.5
1709            the threshold for background prediction
1710
1711        """
1712        self._check_suggestions_validity(suggestions_name, force=force)
1713        if any([x is None for x in suggestion_episodes]):
1714            suggestion_episodes = None
1715        if error_episode is None and (
1716            suggestion_episodes is None and suggestions_prediction is None
1717        ):
1718            raise ValueError(
1719                "Both error_episode and suggestion_episode parameters cannot be None at the same time"
1720            )
1721        print(f"SUGGESTION {suggestions_name}")
1722        task = None
1723        if suggestion_classes is None:
1724            suggestion_classes = []
1725        if exclude_classes is None:
1726            exclude_classes = []
1727        if include_classes is None:
1728            include_classes = []
1729        if isinstance(suggestion_threshold, list):
1730            if len(suggestion_threshold) != len(suggestion_classes):
1731                raise ValueError(
1732                    "The suggestion_threshold parameter has to be either a float value or a list of "
1733                    f"float values of the same length as suggestion_classes (got a list of length "
1734                    f"{len(suggestion_threshold)} for {len(suggestion_classes)} classes)"
1735                )
1736        else:
1737            suggestion_threshold = [suggestion_threshold for _ in suggestion_classes]
1738        if isinstance(suggestion_threshold_diff, list):
1739            if len(suggestion_threshold_diff) != len(suggestion_classes):
1740                raise ValueError(
1741                    "The suggestion_threshold_diff parameter has to be either a float value or a list of "
1742                    f"float values of the same length as suggestion_classes (got a list of length "
1743                    f"{len(suggestion_threshold)} for {len(suggestion_classes)} classes)"
1744                )
1745        else:
1746            suggestion_threshold_diff = [
1747                suggestion_threshold_diff for _ in suggestion_classes
1748            ]
1749        if isinstance(suggestion_hysteresis, list):
1750            if len(suggestion_hysteresis) != len(suggestion_classes):
1751                raise ValueError(
1752                    "The suggestion_threshold_diff parameter has to be either a float value or a list of "
1753                    f"float values of the same length as suggestion_classes (got a list of length "
1754                    f"{len(suggestion_hysteresis)} for {len(suggestion_classes)} classes)"
1755                )
1756        else:
1757            suggestion_hysteresis = [suggestion_hysteresis for _ in suggestion_classes]
1758        if isinstance(exclude_threshold, list):
1759            if len(exclude_threshold) != len(exclude_classes):
1760                raise ValueError(
1761                    "The exclude_threshold parameter has to be either a float value or a list of "
1762                    f"float values of the same length as exclude_classes (got a list of length "
1763                    f"{len(exclude_threshold)} for {len(exclude_classes)} classes)"
1764                )
1765        else:
1766            exclude_threshold = [exclude_threshold for _ in exclude_classes]
1767        if isinstance(exclude_threshold_diff, list):
1768            if len(exclude_threshold_diff) != len(exclude_classes):
1769                raise ValueError(
1770                    "The exclude_threshold_diff parameter has to be either a float value or a list of "
1771                    f"float values of the same length as exclude_classes (got a list of length "
1772                    f"{len(exclude_threshold_diff)} for {len(exclude_classes)} classes)"
1773                )
1774        else:
1775            exclude_threshold_diff = [exclude_threshold_diff for _ in exclude_classes]
1776        if isinstance(exclude_hysteresis, list):
1777            if len(exclude_hysteresis) != len(exclude_classes):
1778                raise ValueError(
1779                    "The suggestion_threshold_diff parameter has to be either a float value or a list of "
1780                    f"float values of the same length as suggestion_classes (got a list of length "
1781                    f"{len(exclude_hysteresis)} for {len(exclude_classes)} classes)"
1782                )
1783        else:
1784            exclude_hysteresis = [exclude_hysteresis for _ in exclude_classes]
1785        if isinstance(include_threshold, list):
1786            if len(include_threshold) != len(include_classes):
1787                raise ValueError(
1788                    "The exclude_threshold parameter has to be either a float value or a list of "
1789                    f"float values of the same length as exclude_classes (got a list of length "
1790                    f"{len(include_threshold)} for {len(include_classes)} classes)"
1791                )
1792        else:
1793            include_threshold = [include_threshold for _ in include_classes]
1794        if isinstance(include_threshold_diff, list):
1795            if len(include_threshold_diff) != len(include_classes):
1796                raise ValueError(
1797                    "The exclude_threshold_diff parameter has to be either a float value or a list of "
1798                    f"float values of the same length as exclude_classes (got a list of length "
1799                    f"{len(include_threshold_diff)} for {len(include_classes)} classes)"
1800                )
1801        else:
1802            include_threshold_diff = [include_threshold_diff for _ in include_classes]
1803        if isinstance(include_hysteresis, list):
1804            if len(include_hysteresis) != len(include_classes):
1805                raise ValueError(
1806                    "The suggestion_threshold_diff parameter has to be either a float value or a list of "
1807                    f"float values of the same length as suggestion_classes (got a list of length "
1808                    f"{len(include_hysteresis)} for {len(include_classes)} classes)"
1809                )
1810        else:
1811            include_hysteresis = [include_hysteresis for _ in include_classes]
1812        if (suggestion_episodes is None and suggestions_prediction is None) and len(
1813            exclude_classes
1814        ) > 0:
1815            raise ValueError(
1816                "In order to exclude classes from the active learning intervals you need to set the "
1817                "suggestion_episode parameter"
1818            )
1819
1820        task = None
1821        if error_episode is not None:
1822            task, parameters, mode = self._make_task_prediction(
1823                prediction_name=suggestions_name,
1824                load_episode=error_episode,
1825                parameters_update=parameters_update,
1826                load_epoch=error_load_epoch,
1827                data_path=data_path,
1828                mode=mode,
1829                file_paths=file_paths,
1830                task=task,
1831            )
1832            predicted_error = task.predict(
1833                data=mode,
1834                raw_output=True,
1835                apply_primary_function=True,
1836                augment_n=augment_n,
1837            )
1838        else:
1839            predicted_error = None
1840
1841        if suggestion_episodes is not None:
1842            (
1843                task,
1844                parameters,
1845                mode,
1846                predicted_classes,
1847                inference_time,
1848                behavior_dict,
1849            ) = self._make_prediction(
1850                prediction_name=suggestions_name,
1851                episode_names=suggestion_episodes,
1852                load_epochs=suggestion_load_epoch,
1853                parameters_update=parameters_update,
1854                data_path=data_path,
1855                file_paths=file_paths,
1856                mode=mode,
1857                task=task,
1858            )
1859        elif suggestions_prediction is not None:
1860            with open(
1861                os.path.join(
1862                    self.project_path,
1863                    "results",
1864                    "predictions",
1865                    f"{suggestions_prediction}.pickle",
1866                ),
1867                "rb",
1868            ) as f:
1869                predicted_classes = pickle.load(f)
1870            if parameters_update is None:
1871                parameters_update = {}
1872            parameters_update = self._update(
1873                self._predictions().load_parameters(suggestions_prediction),
1874                parameters_update,
1875            )
1876            parameters_update.pop("model")
1877            if suggestion_episodes is None:
1878                suggestion_episodes = [
1879                    os.path.basename(
1880                        os.path.dirname(
1881                            parameters_update["training"]["checkpoint_path"]
1882                        )
1883                    )
1884                ]
1885            task, parameters, mode = self._make_task_prediction(
1886                "_",
1887                load_episode=None,
1888                parameters_update=parameters_update,
1889                data_path=data_path,
1890                file_paths=file_paths,
1891                mode=mode,
1892            )
1893        else:
1894            predicted_classes = None
1895
1896        if len(suggestion_classes) > 0 and predicted_classes is not None:
1897            suggestions = self._make_suggestions(
1898                task,
1899                predicted_error,
1900                predicted_classes,
1901                suggestion_threshold,
1902                suggestion_threshold_diff,
1903                suggestion_hysteresis,
1904                suggestion_episodes,
1905                suggestion_classes,
1906                error_threshold,
1907                min_frames_suggestion,
1908                min_frames_al,
1909                visibility_min_score,
1910                visibility_min_frac,
1911                cut_annotated=cut_annotated,
1912            )
1913            videos = list(suggestions.keys())
1914            for v_id in videos:
1915                times_dict = defaultdict(lambda: defaultdict(lambda: []))
1916                clips = set()
1917                for c in suggestions[v_id]:
1918                    for start, end, ind in suggestions[v_id][c]:
1919                        times_dict[ind][c].append([start, end, 2])
1920                        clips.add(ind)
1921                clips = list(clips)
1922                times_dict = dict(times_dict)
1923                times = [
1924                    [times_dict[ind][c] for c in suggestion_classes] for ind in clips
1925                ]
1926                save_path = self._suggestion_path(v_id, suggestions_name)
1927                Path(os.path.dirname(save_path)).mkdir(parents=True, exist_ok=True)
1928                with open(save_path, "wb") as f:
1929                    pickle.dump((None, suggestion_classes, clips, times), f)
1930
1931        if (
1932            error_episode is not None
1933            or len(exclude_classes) > 0
1934            or len(include_classes) > 0
1935        ):
1936            al_points = self._make_al_points(
1937                task,
1938                predicted_error,
1939                predicted_classes,
1940                exclude_classes,
1941                exclude_threshold,
1942                exclude_threshold_diff,
1943                exclude_hysteresis,
1944                include_classes,
1945                include_threshold,
1946                include_threshold_diff,
1947                include_hysteresis,
1948                error_episode,
1949                error_class,
1950                suggestion_episodes,
1951                error_threshold,
1952                error_threshold_diff,
1953                error_hysteresis,
1954                min_frames_al,
1955                visibility_min_score,
1956                visibility_min_frac,
1957            )
1958        else:
1959            al_points = self._make_al_points_from_suggestions(
1960                suggestions_name,
1961                task,
1962                predicted_classes,
1963                background_threshold,
1964                visibility_min_score,
1965                visibility_min_frac,
1966                num_behaviors=len(task.behaviors_dict()),
1967            )
1968        save_path = self._al_points_path(suggestions_name)
1969        Path(os.path.dirname(save_path)).mkdir(parents=True, exist_ok=True)
1970        with open(save_path, "wb") as f:
1971            pickle.dump(al_points, f)
1972
1973        meta_parameters = {
1974            "error_episode": error_episode,
1975            "error_load_epoch": error_load_epoch,
1976            "error_class": error_class,
1977            "suggestion_episode": suggestion_episodes,
1978            "suggestion_load_epoch": suggestion_load_epoch,
1979            "suggestion_classes": suggestion_classes,
1980            "error_threshold": error_threshold,
1981            "error_threshold_diff": error_threshold_diff,
1982            "error_hysteresis": error_hysteresis,
1983            "suggestion_threshold": suggestion_threshold,
1984            "suggestion_threshold_diff": suggestion_threshold_diff,
1985            "suggestion_hysteresis": suggestion_hysteresis,
1986            "min_frames_suggestion": min_frames_suggestion,
1987            "min_frames_al": min_frames_al,
1988            "visibility_min_score": visibility_min_score,
1989            "visibility_min_frac": visibility_min_frac,
1990            "augment_n": augment_n,
1991            "exclude_classes": exclude_classes,
1992            "exclude_threshold": exclude_threshold,
1993            "exclude_threshold_diff": exclude_threshold_diff,
1994            "exclude_hysteresis": exclude_hysteresis,
1995        }
1996        self._save_suggestions(suggestions_name, {}, meta_parameters)
1997        if data_path is not None or file_paths is not None or remove_saved_features:
1998            self._remove_stores(parameters)
1999        print(f"\n")
2000
2001    def _generate_similarity_score(
2002        self,
2003        prediction_name: str,
2004        target_video_id: str,
2005        target_clip: str,
2006        target_start: int,
2007        target_end: int,
2008    ) -> Dict:
2009        with open(
2010            os.path.join(
2011                self.project_path,
2012                "results",
2013                "predictions",
2014                f"{prediction_name}.pickle",
2015            ),
2016            "rb",
2017        ) as f:
2018            prediction = pickle.load(f)
2019        target = prediction[target_video_id][target_clip][:, target_start:target_end]
2020        score_dict = defaultdict(lambda: {})
2021        for video_id in prediction:
2022            for clip_id in prediction[video_id]:
2023                score_dict[video_id][clip_id] = torch.cdist(
2024                    target.T, prediction[video_id][score_dict].T
2025                ).min(0)
2026        return score_dict
2027
2028    def _suggest_intervals_from_dict(self, score_dict, min_length, n_intervals) -> Dict:
2029        """Suggest intervals from a score dictionary.
2030
2031        Parameters
2032        ----------
2033        score_dict : dict
2034            a dictionary containing scores for intervals
2035        min_length : int
2036            minimum length of intervals to suggest
2037        n_intervals : int
2038            number of intervals to suggest
2039
2040        Returns
2041        -------
2042        intervals : dict
2043            a dictionary of suggested intervals
2044
2045        """
2046        interval_address = {}
2047        interval_value = {}
2048        s = 0
2049        n = 0
2050        for video_id, video_dict in score_dict.items():
2051            for clip_id, value in video_dict.items():
2052                s += value.mean()
2053                n += 1
2054        mean_value = s / n
2055        alpha = 1.75
2056        for it in range(10):
2057            id = 0
2058            interval_address = {}
2059            interval_value = {}
2060            for video_id, video_dict in score_dict.items():
2061                for clip_id, value in video_dict.items():
2062                    res_indices_start, res_indices_end = apply_threshold(
2063                        value,
2064                        threshold=(2 - alpha * (0.9**it)) * mean_value,
2065                        low=True,
2066                        error_mask=None,
2067                        min_frames=min_length,
2068                        smooth_interval=0,
2069                    )
2070                    for start, end in zip(res_indices_start, res_indices_end):
2071                        interval_address[id] = [video_id, clip_id, start, end]
2072                        interval_value[id] = score_dict[video_id][clip_id][
2073                            start:end
2074                        ].mean()
2075                        id += 1
2076            if len(interval_address) >= n_intervals:
2077                break
2078        if len(interval_address) < n_intervals:
2079            warnings.warn(
2080                f"Could not get {n_intervals} intervals from the data, saving the result with {len(interval_address)} intervals"
2081            )
2082        sorted_intervals = sorted(
2083            interval_value.items(), key=lambda x: x[1], reverse=True
2084        )
2085        output_intervals = [
2086            interval_address[x[0]]
2087            for x in sorted_intervals[: min(len(sorted_intervals), n_intervals)]
2088        ]
2089        output = defaultdict(lambda: [])
2090        for video_id, clip_id, start, end in output_intervals:
2091            output[video_id].append([start, end, clip_id])
2092        return output
2093
2094    def suggest_intervals_with_similarity(
2095        self,
2096        suggestions_name: str,
2097        prediction_name: str,
2098        target_video_id: str,
2099        target_clip: str,
2100        target_start: int,
2101        target_end: int,
2102        min_length: int = 60,
2103        n_intervals: int = 5,
2104        force: bool = False,
2105    ):
2106        """
2107        Suggest intervals based on similarity to a target interval.
2108
2109        Parameters
2110        ----------
2111        suggestions_name : str
2112            Name of the suggestion.
2113        prediction_name : str
2114            Name of the prediction to use.
2115        target_video_id : str
2116            Video id of the target interval.
2117        target_clip : str
2118            Clip id of the target interval.
2119        target_start : int
2120            Start frame of the target interval.
2121        target_end : int
2122            End frame of the target interval.
2123        min_length : int, default 60
2124            Minimum length of the suggested intervals.
2125        n_intervals : int, default 5
2126            Number of suggested intervals.
2127        force : bool, default False
2128            If True, the suggestion is overwritten if it already exists.
2129
2130        """
2131        self._check_suggestions_validity(suggestions_name, force=force)
2132        print(f"SUGGESTION {suggestions_name}")
2133        score_dict = self._generate_similarity_score(
2134            prediction_name, target_video_id, target_clip, target_start, target_end
2135        )
2136        intervals = self._suggest_intervals_from_dict(
2137            score_dict, min_length, n_intervals
2138        )
2139        suggestions_path = os.path.join(
2140            self.project_path,
2141            "results",
2142            "suggestions",
2143            suggestions_name,
2144        )
2145        if not os.path.exists(suggestions_path):
2146            os.mkdir(suggestions_path)
2147        with open(
2148            os.path.join(suggestions_path, f"{suggestions_name}_al_points.pickle"), "wb"
2149        ) as f:
2150            pickle.dump(intervals, f)
2151        meta_parameters = {
2152            "prediction_name": prediction_name,
2153            "min_frames_suggestion": min_length,
2154            "n_intervals": n_intervals,
2155            "target_clip": target_clip,
2156            "target_start": target_start,
2157            "target_end": target_end,
2158        }
2159        self._save_suggestions(suggestions_name, {}, meta_parameters)
2160        print("\n")
2161
2162    def suggest_intervals_with_uncertainty(
2163        self,
2164        suggestions_name: str,
2165        episode_names: List,
2166        load_epochs: List = None,
2167        classes: List = None,
2168        n_frames: int = 10000,
2169        method: str = "least_confidence",
2170        min_length: int = 60,
2171        augment_n: int = 0,
2172        data_path: str = None,
2173        file_paths: Set = None,
2174        parameters_update: Dict = None,
2175        mode: str = "all",
2176        force: bool = False,
2177        remove_saved_features: bool = False,
2178    ) -> None:
2179        """Generate an active learning file based on model uncertainty.
2180
2181        If you provide several episode names, the predicted probabilities will be averaged.
2182
2183        Parameters
2184        ----------
2185        suggestions_name : str
2186            the name of the suggestion
2187        episode_names : list
2188            a list of string episode names to load the models from
2189        load_epochs : list, optional
2190            a list of epoch indices to load the models from (if `None`, the last ones will be used)
2191        classes : list, optional
2192            a list of classes to look at (by default all)
2193        n_frames : int, default 10000
2194            the threshold total number of frames in the suggested intervals (in the end result it will most likely
2195            be slightly larger; it will only be smaller if the algorithm fails to find enough intervals
2196            with the set parameters)
2197        method : {"least_confidence", "entropy"}
2198            the method used to calculate the scores from the probability predictions (`"least_confidence"`: `1 - p_i` if
2199            `p_i > 0.5` or `p_i` if `p_i < 0.5`; `"entropy"`: `- p_i * log(p_i) - (1 - p_i) * log(1 - p_i)`)
2200        min_length : int, default 60
2201            the minimum number of frames in one interval
2202        augment_n : int, default 0
2203            the number of augmentations to average the predictions over
2204        data_path : str, optional
2205            the path to a data folder (by default, the project data is used)
2206        file_paths : set, optional
2207            a list of file paths (by default, the project data is used)
2208        parameters_update : dict, optional
2209            a dictionary of parameter updates
2210        mode : {"test", "val", "train", "all"}
2211            the subset of the data to make the prediction for (forced to 'all' if `data_path` is not `None`;
2212            by default set to `'test'` if the test subset if not empty, or to `'val'` otherwise)
2213        force : bool, default False
2214            if `True`, existing suggestions with the same name will be overwritten
2215        remove_saved_features : bool, default False
2216            if `True`, the dataset will be deleted after the computation
2217
2218        """
2219        self._check_suggestions_validity(suggestions_name, force=force)
2220        print(f"SUGGESTION {suggestions_name}")
2221        task, parameters, mode, predicted, inference_time, behavior_dict = (
2222            self._make_prediction(
2223                suggestions_name,
2224                episode_names,
2225                load_epochs,
2226                parameters_update,
2227                data_path=data_path,
2228                file_paths=file_paths,
2229                mode=mode,
2230                augment_n=augment_n,
2231                evaluate=False,
2232            )
2233        )
2234        if classes is None:
2235            classes = behavior_dict.values()
2236        episode = self._episodes().get_runs(episode_names[0])[0]
2237        score_tensors = task.generate_uncertainty_score(
2238            classes,
2239            augment_n,
2240            method,
2241            predicted,
2242            self._episode(episode).get_behaviors_dict(),
2243        )
2244        intervals = self._suggest_intervals(
2245            task.dataset(mode), score_tensors, n_frames, min_length
2246        )
2247        for k, v in intervals.items():
2248            l = sum([x[1] - x[0] for x in v])
2249            print(f"{k}: {len(v)} ({l})")
2250        if remove_saved_features:
2251            self._remove_stores(parameters)
2252        suggestions_path = os.path.join(
2253            self.project_path,
2254            "results",
2255            "suggestions",
2256            suggestions_name,
2257        )
2258        if not os.path.exists(suggestions_path):
2259            os.mkdir(suggestions_path)
2260        with open(
2261            os.path.join(suggestions_path, f"{suggestions_name}_al_points.pickle"), "wb"
2262        ) as f:
2263            pickle.dump(intervals, f)
2264        meta_parameters = {
2265            "suggestion_episode": episode_names,
2266            "suggestion_load_epoch": load_epochs,
2267            "suggestion_classes": classes,
2268            "min_frames_suggestion": min_length,
2269            "augment_n": augment_n,
2270            "method": method,
2271            "num_frames": n_frames,
2272        }
2273        self._save_suggestions(suggestions_name, {}, meta_parameters)
2274        print("\n")
2275
2276    def suggest_intervals_with_bald(
2277        self,
2278        suggestions_name: str,
2279        episode_name: str,
2280        load_epoch: int = None,
2281        classes: List = None,
2282        n_frames: int = 10000,
2283        num_models: int = 10,
2284        kernel_size: int = 11,
2285        min_length: int = 60,
2286        augment_n: int = 0,
2287        data_path: str = None,
2288        file_paths: Set = None,
2289        parameters_update: Dict = None,
2290        mode: str = "all",
2291        force: bool = False,
2292        remove_saved_features: bool = False,
2293    ):
2294        """Generate an active learning file based on Bayesian Active Learning by Disagreement.
2295
2296        Parameters
2297        ----------
2298        suggestions_name : str
2299            the name of the suggestion
2300        episode_name : str
2301            the name of the episode to load the model from
2302        load_epoch : int, optional
2303            the index of the epoch to load the model from (if `None`, the last one will be used)
2304        classes : list, optional
2305            a list of classes to look at (by default all)
2306        n_frames : int, default 10000
2307            the threshold total number of frames in the suggested intervals (in the end result it will most likely
2308            be slightly larger; it will only be smaller if the algorithm fails to find enough intervals
2309            with the set parameters)
2310        num_models : int, default 10
2311            the number of dropout masks to apply
2312        kernel_size : int, default 11
2313            the size of the smoothing kernel applied to the discrete results
2314        min_length : int, default 60
2315            the minimum number of frames in one interval
2316        augment_n : int, default 0
2317            the number of augmentations to average the predictions over
2318        data_path : str, optional
2319            the path to a data folder (by default, the project data is used)
2320        file_paths : set, optional
2321            a list of file paths (by default, the project data is used)
2322        parameters_update : dict, optional
2323            a dictionary of parameter updates
2324        mode : {"test", "val", "train", "all"}
2325            the subset of the data to make the prediction for (forced to 'all' if `data_path` is not `None`;
2326            by default set to `'test'` if the test subset if not empty, or to `'val'` otherwise)
2327        force : bool, default False
2328            if `True`, existing suggestions with the same name will be overwritten
2329        remove_saved_features : bool, default False
2330            if `True`, the dataset will be deleted after the computation
2331
2332        """
2333        self._check_suggestions_validity(suggestions_name, force=force)
2334        print(f"SUGGESTION {suggestions_name}")
2335        task, parameters, mode = self._make_task_prediction(
2336            suggestions_name,
2337            episode_name,
2338            parameters_update,
2339            load_epoch,
2340            data_path=data_path,
2341            file_paths=file_paths,
2342            mode=mode,
2343        )
2344        if classes is None:
2345            classes = list(task.behaviors_dict().values())
2346        score_tensors = task.generate_bald_score(
2347            classes, augment_n, num_models, kernel_size
2348        )
2349        intervals = self._suggest_intervals(
2350            task.dataset(mode), score_tensors, n_frames, min_length
2351        )
2352        if remove_saved_features:
2353            self._remove_stores(parameters)
2354        suggestions_path = os.path.join(
2355            self.project_path,
2356            "results",
2357            "suggestions",
2358            suggestions_name,
2359        )
2360        if not os.path.exists(suggestions_path):
2361            os.mkdir(suggestions_path)
2362        with open(
2363            os.path.join(suggestions_path, f"{suggestions_name}_al_points.pickle"), "wb"
2364        ) as f:
2365            pickle.dump(intervals, f)
2366        meta_parameters = {
2367            "suggestion_episode": episode_name,
2368            "suggestion_load_epoch": load_epoch,
2369            "suggestion_classes": classes,
2370            "min_frames_suggestion": min_length,
2371            "augment_n": augment_n,
2372            "method": f"BALD:{num_models}",
2373            "num_frames": n_frames,
2374        }
2375        self._save_suggestions(suggestions_name, {}, meta_parameters)
2376        print("\n")
2377
2378    def list_episodes(
2379        self,
2380        episode_names: List = None,
2381        value_filter: str = "",
2382        display_parameters: List = None,
2383        print_results: bool = True,
2384    ) -> pd.DataFrame:
2385        """Get a filtered pandas dataframe with episode metadata.
2386
2387        Parameters
2388        ----------
2389        episode_names : list
2390            a list of strings of episode names
2391        value_filter : str
2392            a string of filters to apply; of this general structure:
2393            'group_name1/par_name1::(</>/<=/>=/=)value1,group_name2/par_name2::(</>/<=/>=/=)value2', e.g.
2394            'data/overlap::=50,results/recall::>0.5,data/feature_extraction::=kinematic,meta/training_time::>=00:00:10'
2395        display_parameters : list
2396            list of parameters to display (e.g. ['data/overlap', 'results/recall'])
2397        print_results : bool, default True
2398            if True, the result will be printed to standard output
2399
2400        Returns
2401        -------
2402        pd.DataFrame
2403            the filtered dataframe
2404
2405        """
2406        episodes = self._episodes().list_episodes(
2407            episode_names, value_filter, display_parameters
2408        )
2409        if print_results:
2410            print("TRAINING EPISODES")
2411            print(episodes)
2412            print("\n")
2413        return episodes
2414
2415    def list_predictions(
2416        self,
2417        episode_names: List = None,
2418        value_filter: str = "",
2419        display_parameters: List = None,
2420        print_results: bool = True,
2421    ) -> pd.DataFrame:
2422        """Get a filtered pandas dataframe with prediction metadata.
2423
2424        Parameters
2425        ----------
2426        episode_names : list
2427            a list of strings of episode names
2428        value_filter : str
2429            a string of filters to apply; of this general structure:
2430            'group_name1/par_name1:(<>=)value1,group_name2/par_name2:(<>=)value2', e.g.
2431            'data/overlap:=50,results/recall:>0.5,data/feature_extraction:=kinematic'
2432        display_parameters : list
2433            list of parameters to display (e.g. ['data/overlap', 'results/recall'])
2434        print_results : bool, default True
2435            if True, the result will be printed to standard output
2436
2437        Returns
2438        -------
2439        pd.DataFrame
2440            the filtered dataframe
2441
2442        """
2443        predictions = self._predictions().list_episodes(
2444            episode_names, value_filter, display_parameters
2445        )
2446        if print_results:
2447            print("PREDICTIONS")
2448            print(predictions)
2449            print("\n")
2450        return predictions
2451
2452    def list_suggestions(
2453        self,
2454        suggestions_names: List = None,
2455        value_filter: str = "",
2456        display_parameters: List = None,
2457        print_results: bool = True,
2458    ) -> pd.DataFrame:
2459        """Get a filtered pandas dataframe with prediction metadata.
2460
2461        Parameters
2462        ----------
2463        suggestions_names : list
2464            a list of strings of suggestion names
2465        value_filter : str
2466            a string of filters to apply; of this general structure:
2467            'group_name1/par_name1:(<>=)value1,group_name2/par_name2:(<>=)value2', e.g.
2468            'data/overlap:=50,results/recall:>0.5,data/feature_extraction:=kinematic'
2469        display_parameters : list
2470            list of parameters to display (e.g. ['data/overlap', 'results/recall'])
2471        print_results : bool, default True
2472            if True, the result will be printed to standard output
2473
2474        Returns
2475        -------
2476        pd.DataFrame
2477            the filtered dataframe
2478
2479        """
2480        suggestions = self._suggestions().list_episodes(
2481            suggestions_names, value_filter, display_parameters
2482        )
2483        if print_results:
2484            print("SUGGESTIONS")
2485            print(suggestions)
2486            print("\n")
2487        return suggestions
2488
2489    def list_searches(
2490        self,
2491        search_names: List = None,
2492        value_filter: str = "",
2493        display_parameters: List = None,
2494        print_results: bool = True,
2495    ) -> pd.DataFrame:
2496        """Get a filtered pandas dataframe with hyperparameter search metadata.
2497
2498        Parameters
2499        ----------
2500        search_names : list
2501            a list of strings of search names
2502        value_filter : str
2503            a string of filters to apply; of this general structure:
2504            'group_name1/par_name1:(<>=)value1,group_name2/par_name2:(<>=)value2', e.g.
2505            'data/overlap:=50,results/recall:>0.5,data/feature_extraction:=kinematic'
2506        display_parameters : list
2507            list of parameters to display (e.g. ['data/overlap', 'results/recall'])
2508        print_results : bool, default True
2509            if True, the result will be printed to standard output
2510
2511        Returns
2512        -------
2513        pd.DataFrame
2514            the filtered dataframe
2515
2516        """
2517        searches = self._searches().list_episodes(
2518            search_names, value_filter, display_parameters
2519        )
2520        if print_results:
2521            print("SEARCHES")
2522            print(searches)
2523            print("\n")
2524        return searches
2525
2526    def get_best_parameters(
2527        self,
2528        search_name: str,
2529        round_to_binary: List = None,
2530    ):
2531        """Get the best parameters found by a search.
2532
2533        Parameters
2534        ----------
2535        search_name : str
2536            the name of the search
2537        round_to_binary : list, default None
2538            a list of parameters to round to binary values
2539
2540        Returns
2541        -------
2542        best_params : dict
2543            a dictionary of the best parameters where the keys are in '{group}/{name}' format
2544
2545        """
2546        params, model = self._searches().get_best_params(
2547            search_name, round_to_binary=round_to_binary
2548        )
2549        params = self._update(params, {"general": {"model_name": model}})
2550        return params
2551
2552    def list_best_parameters(
2553        self, search_name: str, print_results: bool = True
2554    ) -> Dict:
2555        """Get the raw dictionary of best parameters found by a search.
2556
2557        Parameters
2558        ----------
2559        search_name : str
2560            the name of the search
2561        print_results : bool, default True
2562            if True, the result will be printed to standard output
2563
2564        Returns
2565        -------
2566        best_params : dict
2567            a dictionary of the best parameters where the keys are in '{group}/{name}' format
2568
2569        """
2570        params = self._searches().get_best_params_raw(search_name)
2571        if print_results:
2572            print(f"SEARCH RESULTS {search_name}")
2573            for k, v in params.items():
2574                print(f"{k}: {v}")
2575            print("\n")
2576        return params
2577
2578    def plot_episodes(
2579        self,
2580        episode_names: List,
2581        metrics: List | str,
2582        modes: List | str = None,
2583        title: str = None,
2584        episode_labels: List = None,
2585        save_path: str = None,
2586        add_hlines: List = None,
2587        epoch_limits: List = None,
2588        colors: List = None,
2589        add_highpoint_hlines: bool = False,
2590        remove_box: bool = False,
2591        font_size: float = None,
2592        linewidth: float = None,
2593        return_ax: bool = False,
2594    ) -> None:
2595        """Plot episode training curves.
2596
2597        Parameters
2598        ----------
2599        episode_names : list
2600            a list of episode names to plot; to plot to episodes in one line combine them in a list
2601            (e.g. ['episode1', ['episode2', 'episode3']] to plot episode2 and episode3 as one experiment)
2602        metrics : list
2603            a list of metric to plot
2604        modes : list, optional
2605            a list of modes to plot ('train' and/or 'val'; `['val']` by default)
2606        title : str, optional
2607            title for the plot
2608        episode_labels : list, optional
2609            a list of strings used to label the curves (has to be the same length as episode_names)
2610        save_path : str, optional
2611            the path to save the resulting plot
2612        add_hlines : list, optional
2613            a list of float values (or (value, label) tuples) to mark with horizontal lines
2614        epoch_limits : list, optional
2615            a list of (min, max) tuples to set the x-axis limits for each episode
2616        colors: list, optional
2617            a list of matplotlib colors
2618        add_highpoint_hlines : bool, default False
2619            if `True`, horizontal lines will be added at the highest value of each episode
2620        """
2621
2622        if isinstance(metrics, str):
2623            metrics = [metrics]
2624        if isinstance(modes, str):
2625            modes = [modes]
2626
2627        if font_size is not None:
2628            font = {"size": font_size}
2629            rc("font", **font)
2630        if modes is None:
2631            modes = ["val"]
2632        if add_hlines is None:
2633            add_hlines = []
2634        logs = []
2635        epochs = []
2636        labels = []
2637        if episode_labels is not None:
2638            assert len(episode_labels) == len(episode_names)
2639        for name_i, name in enumerate(episode_names):
2640            log_params = product(metrics, modes)
2641            for metric, mode in log_params:
2642                if episode_labels is not None:
2643                    label = episode_labels[name_i]
2644                else:
2645                    label = deepcopy(name)
2646                if len(modes) != 1:
2647                    label += f"_{mode}"
2648                if len(metrics) != 1:
2649                    label += f"_{metric}"
2650                labels.append(label)
2651                if isinstance(name, Iterable) and not isinstance(name, str):
2652                    epoch_list = defaultdict(lambda: [])
2653                    multi_logs = defaultdict(lambda: [])
2654                    for i, n in enumerate(name):
2655                        runs = self._episodes().get_runs(n)
2656                        if len(runs) > 1:
2657                            for run in runs:
2658                                if "::" in run:
2659                                    index = run.split("::")[-1]
2660                                else:
2661                                    index = run.split("#")[-1]
2662                                if multi_logs[index] == []:
2663                                    if multi_logs["null"] is None:
2664                                        raise RuntimeError(
2665                                            "The run indices are not consistent across episodes!"
2666                                        )
2667                                    else:
2668                                        multi_logs[index] += multi_logs["null"]
2669                                multi_logs[index] += list(
2670                                    self._episode(run).get_metric_log(mode, metric)
2671                                )
2672                                start = (
2673                                    0
2674                                    if len(epoch_list[index]) == 0
2675                                    else epoch_list[index][-1]
2676                                )
2677                                epoch_list[index] += [
2678                                    x + start
2679                                    for x in self._episode(run).get_epoch_list(mode)
2680                                ]
2681                            multi_logs["null"] = None
2682                        else:
2683                            if len(multi_logs.keys()) > 1:
2684                                raise RuntimeError(
2685                                    "Cannot plot a single-run episode after a multi-run episode!"
2686                                )
2687                            multi_logs["null"] += list(
2688                                self._episode(n).get_metric_log(mode, metric)
2689                            )
2690                            start = (
2691                                0
2692                                if len(epoch_list["null"]) == 0
2693                                else epoch_list["null"][-1]
2694                            )
2695                            epoch_list["null"] += [
2696                                x + start for x in self._episode(n).get_epoch_list(mode)
2697                            ]
2698                    if len(multi_logs.keys()) == 1:
2699                        log = multi_logs["null"]
2700                        epochs.append(epoch_list["null"])
2701                    else:
2702                        log = tuple([v for k, v in multi_logs.items() if k != "null"])
2703                        epochs.append(
2704                            tuple([v for k, v in epoch_list.items() if k != "null"])
2705                        )
2706                else:
2707                    runs = self._episodes().get_runs(name)
2708                    if len(runs) > 1:
2709                        log = []
2710                        for run in runs:
2711                            tracked_metrics = self._episode(run).get_metrics()
2712                            if metric in tracked_metrics:
2713                                log.append(
2714                                    list(
2715                                        self._episode(run).get_metric_log(mode, metric)
2716                                    )
2717                                )
2718                            else:
2719                                relevant = []
2720                                for m in tracked_metrics:
2721                                    m_split = m.split("_")
2722                                    if (
2723                                        "_".join(m_split[:-1]) == metric
2724                                        and m_split[-1].isnumeric()
2725                                    ):
2726                                        relevant.append(m)
2727                                if len(relevant) == 0:
2728                                    raise ValueError(
2729                                        f"The {metric} metric was not tracked at {run}"
2730                                    )
2731                                arr = 0
2732                                for m in relevant:
2733                                    arr += self._episode(run).get_metric_log(mode, m)
2734                                arr /= len(relevant)
2735                                log.append(list(arr))
2736                        log = tuple(log)
2737                        epochs.append(
2738                            tuple(
2739                                [
2740                                    self._episode(run).get_epoch_list(mode)
2741                                    for run in runs
2742                                ]
2743                            )
2744                        )
2745                    else:
2746                        tracked_metrics = self._episode(name).get_metrics()
2747                        if metric in tracked_metrics:
2748                            log = list(self._episode(name).get_metric_log(mode, metric))
2749                        else:
2750                            relevant = []
2751                            for m in tracked_metrics:
2752                                m_split = m.split("_")
2753                                if (
2754                                    "_".join(m_split[:-1]) == metric
2755                                    and m_split[-1].isnumeric()
2756                                ):
2757                                    relevant.append(m)
2758                            if len(relevant) == 0:
2759                                raise ValueError(
2760                                    f"The {metric} metric was not tracked at {name}"
2761                                )
2762                            arr = 0
2763                            for m in relevant:
2764                                arr += self._episode(name).get_metric_log(mode, m)
2765                            arr /= len(relevant)
2766                            log = list(arr)
2767                        epochs.append(self._episode(name).get_epoch_list(mode))
2768                logs.append(log)
2769        # if episode_labels is not None:
2770        #     print(f'{len(episode_labels)=}, {len(logs)=}')
2771        #     if len(episode_labels) != len(logs):
2772
2773        #         raise ValueError(
2774        #             f"The length of episode_labels ({len(episode_labels)}) has to be equal to the length of "
2775        #             f"curves ({len(logs)})!"
2776        #         )
2777        #     else:
2778        #         labels = episode_labels
2779        if colors is None:
2780            colors = cm.rainbow(np.linspace(0, 1, len(logs)))
2781        if len(colors) != len(logs):
2782            raise ValueError(
2783                "The length of colors has to be equal to the length of curves (metrics * modes * episode_names)!"
2784            )
2785        f, ax = plt.subplots()
2786        length = 0
2787        for log, label, color, epoch_list in zip(logs, labels, colors, epochs):
2788            if type(log) is list:
2789                if len(log) > length:
2790                    length = len(log)
2791                ax.plot(
2792                    epoch_list,
2793                    log,
2794                    label=label,
2795                    color=color,
2796                )
2797                if add_highpoint_hlines:
2798                    ax.axhline(np.max(log), linestyle="dashed", color=color)
2799            else:
2800                for l, xx in zip(log, epoch_list):
2801                    if len(l) > length:
2802                        length = len(l)
2803                    ax.plot(
2804                        xx,
2805                        l,
2806                        color=color,
2807                        alpha=0.2,
2808                    )
2809                if not all([len(x) == len(log[0]) for x in log]):
2810                    warnings.warn(
2811                        f"Got logs with unequal lengths in parallel runs for {label}"
2812                    )
2813                    log = list(log)
2814                    epoch_list = list(epoch_list)
2815                    for i, x in enumerate(epoch_list):
2816                        to_remove = []
2817                        for j, y in enumerate(x[1:]):
2818                            if y <= x[j - 1]:
2819                                y_ind = x.index(y)
2820                                to_remove += list(range(y_ind, j))
2821                        epoch_list[i] = [
2822                            y for j, y in enumerate(x) if j not in to_remove
2823                        ]
2824                        log[i] = [y for j, y in enumerate(log[i]) if j not in to_remove]
2825                    length = min([len(x) for x in log])
2826                    for i in range(len(log)):
2827                        log[i] = log[i][:length]
2828                        epoch_list[i] = epoch_list[i][:length]
2829                    if not all([x == epoch_list[0] for x in epoch_list]):
2830                        raise RuntimeError(
2831                            f"Got different epoch indices in parallel runs for {label}"
2832                        )
2833                mean = np.array(log).mean(0)
2834                ax.plot(
2835                    epoch_list[0],
2836                    mean,
2837                    label=label,
2838                    color=color,
2839                    linewidth=linewidth,
2840                )
2841                if add_highpoint_hlines:
2842                    ax.axhline(np.max(mean), linestyle="dashed", color=color)
2843        for x in add_hlines:
2844            label = None
2845            if isinstance(x, Iterable):
2846                x, label = x
2847            ax.axhline(x, label=label)
2848            ax.set_xlim((0, length))
2849
2850        ax.legend()
2851        ax.set_xlabel("epochs")
2852        if len(metrics) == 1:
2853            ax.set_ylabel(metrics[0])
2854        else:
2855            ax.set_ylabel("value")
2856        if title is None:
2857            if len(episode_names) == 1:
2858                title = episode_names[0]
2859            elif len(metrics) == 1:
2860                title = metrics[0]
2861        if epoch_limits is not None:
2862            ax.set_xlim(epoch_limits)
2863        if title is not None:
2864            ax.set_title(title)
2865        if remove_box:
2866            ax.box(False)
2867        if return_ax:
2868            return ax
2869        if save_path is not None:
2870            plt.savefig(save_path)
2871        plt.show()
2872
2873    def update_parameters(
2874        self,
2875        parameters_update: Dict = None,
2876        load_search: str = None,
2877        load_parameters: List = None,
2878        round_to_binary: List = None,
2879    ) -> None:
2880        """Update the parameters in the project config files.
2881
2882        Parameters
2883        ----------
2884        parameters_update : dict, optional
2885            a dictionary of parameter updates
2886        load_search : str, optional
2887            the name of hyperparameter search results to load to config
2888        load_parameters : list, optional
2889            a list of lists of string names of the parameters to load from the searches
2890        round_to_binary : list, optional
2891            a list of string names of the loaded parameters that should be rounded to the nearest power of two
2892
2893        """
2894        keys = [
2895            "general",
2896            "losses",
2897            "metrics",
2898            "ssl",
2899            "training",
2900            "data",
2901        ]
2902        parameters = self._read_parameters(catch_blanks=False)
2903        if parameters_update is not None:
2904            model_params = (
2905                parameters_update.pop("model") if "model" in parameters_update else None
2906            )
2907            feat_params = (
2908                parameters_update.pop("features")
2909                if "features" in parameters_update
2910                else None
2911            )
2912            aug_params = (
2913                parameters_update.pop("augmentations")
2914                if "augmentations" in parameters_update
2915                else None
2916            )
2917
2918            parameters = self._update(parameters, parameters_update)
2919            model_name = parameters["general"]["model_name"]
2920            parameters["model"] = self._open_yaml(
2921                os.path.join(self.project_path, "config", "model", f"{model_name}.yaml")
2922            )
2923            if model_params is not None:
2924                parameters["model"] = self._update(parameters["model"], model_params)
2925            feat_name = parameters["general"]["feature_extraction"]
2926            parameters["features"] = self._open_yaml(
2927                os.path.join(
2928                    self.project_path, "config", "features", f"{feat_name}.yaml"
2929                )
2930            )
2931            if feat_params is not None:
2932                parameters["features"] = self._update(
2933                    parameters["features"], feat_params
2934                )
2935            aug_name = options.extractor_to_transformer[
2936                parameters["general"]["feature_extraction"]
2937            ]
2938            parameters["augmentations"] = self._open_yaml(
2939                os.path.join(
2940                    self.project_path, "config", "augmentations", f"{aug_name}.yaml"
2941                )
2942            )
2943            if aug_params is not None:
2944                parameters["augmentations"] = self._update(
2945                    parameters["augmentations"], aug_params
2946                )
2947        if load_search is not None:
2948            parameters_update, model_name = self._searches().get_best_params(
2949                load_search, load_parameters, round_to_binary
2950            )
2951            parameters["general"]["model_name"] = model_name
2952            parameters["model"] = self._open_yaml(
2953                os.path.join(self.project_path, "config", "model", f"{model_name}.yaml")
2954            )
2955            parameters = self._update(parameters, parameters_update)
2956        for key in keys:
2957            with open(
2958                os.path.join(self.project_path, "config", f"{key}.yaml"),
2959                "w",
2960                encoding="utf-8",
2961            ) as f:
2962                YAML().dump(parameters[key], f)
2963        model_name = parameters["general"]["model_name"]
2964        model_path = os.path.join(
2965            self.project_path, "config", "model", f"{model_name}.yaml"
2966        )
2967        with open(model_path, "w", encoding="utf-8") as f:
2968            YAML().dump(parameters["model"], f)
2969        features_name = parameters["general"]["feature_extraction"]
2970        features_path = os.path.join(
2971            self.project_path, "config", "features", f"{features_name}.yaml"
2972        )
2973        with open(features_path, "w", encoding="utf-8") as f:
2974            YAML().dump(parameters["features"], f)
2975        aug_name = options.extractor_to_transformer[features_name]
2976        aug_path = os.path.join(
2977            self.project_path, "config", "augmentations", f"{aug_name}.yaml"
2978        )
2979        with open(aug_path, "w", encoding="utf-8") as f:
2980            YAML().dump(parameters["augmentations"], f)
2981
2982    def get_summary(
2983        self,
2984        episode_names: list,
2985        method: str = "last",
2986        average: int = 1,
2987        metrics: List = None,
2988        return_values: bool = False,
2989    ) -> Dict:
2990        """Get a summary of episode statistics.
2991
2992        If an episode has multiple runs, the statistics will be aggregated over all of them.
2993
2994        Parameters
2995        ----------
2996        episode_names : str
2997            the names of the episodes
2998        method : ["best", "last"]
2999            the method for choosing the epochs
3000        average : int, default 1
3001            the number of epochs to average over (for each run)
3002        metrics : list, optional
3003            a list of metrics
3004
3005        Returns
3006        -------
3007        statistics : dict
3008            a nested dictionary where first-level keys are metric names and second-level keys are 'mean' for the mean
3009            and 'std' for the standard deviation
3010
3011        """
3012        runs = []
3013        for episode_name in episode_names:
3014            runs_ep = self._episodes().get_runs(episode_name)
3015            if len(runs_ep) == 0:
3016                raise RuntimeError(
3017                    f"There is no {episode_name} episode in the project memory"
3018                )
3019            runs += runs_ep
3020        if metrics is None:
3021            metrics = self._episode(runs[0]).get_metrics()
3022
3023        values = {m: [] for m in metrics}
3024        for run in runs:
3025            for m in metrics:
3026                log = self._episode(run).get_metric_log(mode="val", metric_name=m)
3027                if method == "best":
3028                    log = sorted(log)
3029                    values[m] += list(log[-average:])
3030                elif method == "last":
3031                    if len(log) == 0:
3032                        episodes = self._episodes().data
3033                        if average == 1 and ("results", m) in episodes.columns:
3034                            values[m] += [episodes.loc[run, ("results", m)]]
3035                        else:
3036                            raise RuntimeError(f"Did not find {m} metric for {run} run")
3037                    values[m] += list(log[-average:])
3038                elif method.startswith("epoch"):
3039                    epoch = int(method[5:]) - 1
3040                    pars = self._episodes().load_parameters(run)
3041                    step = int(pars["training"]["validation_interval"])
3042                    values[m] += [log[epoch // step]]
3043                else:
3044                    raise ValueError(
3045                        f"The {method} method is not recognized! Please choose from ['last', 'best', 'epoch...']"
3046                    )
3047        statistics = defaultdict(lambda: {})
3048        for m, v in values.items():
3049            statistics[m]["mean"] = np.mean(v)
3050            statistics[m]["std"] = np.std(v)
3051        print(f"SUMMARY {episode_names}")
3052        for m, v in statistics.items():
3053            print(f'{m}: mean {v["mean"]:.3f}, std {v["std"]:.3f}')
3054        print("\n")
3055
3056        return (dict(statistics), values) if return_values else dict(statistics)
3057
3058    @staticmethod
3059    def remove_project(name: str, projects_path: str = None) -> None:
3060        """Remove all project files and experiment records and results.
3061
3062        Parameters
3063        ----------
3064        name : str
3065            the name of the project to remove
3066        projects_path : str, optional
3067            the path to the projects directory (by default the home DLC2Action directory)
3068
3069        """
3070        if projects_path is None:
3071            projects_path = os.path.join(str(Path.home()), "DLC2Action")
3072        project_path = os.path.join(projects_path, name)
3073        if os.path.exists(project_path):
3074            shutil.rmtree(project_path)
3075
3076    def remove_saved_features(
3077        self,
3078        dataset_names: List = None,
3079        exceptions: List = None,
3080        remove_active: bool = False,
3081    ) -> None:
3082        """Remove saved pre-computed dataset feature files.
3083
3084        By default, all features will be deleted.
3085        No essential information can get lost, storing them only saves time. Be careful with deleting datasets
3086        while training or inference is happening though.
3087
3088        Parameters
3089        ----------
3090        dataset_names : list, optional
3091            a list of dataset names to delete (by default all names are added)
3092        exceptions : list, optional
3093            a list of dataset names to not be deleted
3094        remove_active : bool, default False
3095            if `False`, datasets used by unfinished episodes will not be deleted
3096
3097        """
3098        print("Removing datasets...")
3099        if dataset_names is None:
3100            dataset_names = []
3101        if exceptions is None:
3102            exceptions = []
3103        if not remove_active:
3104            exceptions += self._episodes().get_active_datasets()
3105        dataset_path = os.path.join(self.project_path, "saved_datasets")
3106        if os.path.exists(dataset_path):
3107            if dataset_names == []:
3108                dataset_names = set([f.split(".")[0] for f in os.listdir(dataset_path)])
3109
3110            to_remove = [
3111                x
3112                for x in dataset_names
3113                if os.path.exists(os.path.join(dataset_path, x)) and x not in exceptions
3114            ]
3115            if len(to_remove) > 2:
3116                to_remove = tqdm(to_remove)
3117            for dataset in to_remove:
3118                shutil.rmtree(os.path.join(dataset_path, dataset))
3119            to_remove = [
3120                f"{x}.pickle"
3121                for x in dataset_names
3122                if os.path.exists(os.path.join(dataset_path, f"{x}.pickle"))
3123                and x not in exceptions
3124            ]
3125            for dataset in to_remove:
3126                os.remove(os.path.join(dataset_path, dataset))
3127            names = self._saved_datasets().dataset_names()
3128            self._saved_datasets().remove(names)
3129        print("\n")
3130
3131    def remove_extra_checkpoints(
3132        self, episode_names: List = None, exceptions: List = None
3133    ) -> None:
3134        """Remove intermediate model checkpoint files (only leave the files for the last epoch).
3135
3136        By default, all intermediate checkpoints will be deleted.
3137        Files in the model folder that are not associated with any record in the meta files are also deleted.
3138
3139        Parameters
3140        ----------
3141        episode_names : list, optional
3142            a list of episode names to clean (by default all names are added)
3143        exceptions : list, optional
3144            a list of episode names to not clean
3145
3146        """
3147        model_path = os.path.join(self.project_path, "results", "model")
3148        try:
3149            all_names = self._episodes().data.index
3150        except:
3151            all_names = os.listdir(model_path)
3152        if episode_names is None:
3153            episode_names = all_names
3154        if exceptions is None:
3155            exceptions = []
3156        to_remove = [x for x in episode_names if x not in exceptions]
3157        folders = os.listdir(model_path)
3158        for folder in folders:
3159            if folder not in all_names:
3160                shutil.rmtree(os.path.join(model_path, folder))
3161            elif folder in to_remove:
3162                files = os.listdir(os.path.join(model_path, folder))
3163                for file in sorted(files)[:-1]:
3164                    os.remove(os.path.join(model_path, folder, file))
3165
3166    def remove_search(self, search_name: str) -> None:
3167        """Remove a hyperparameter search record.
3168
3169        Parameters
3170        ----------
3171        search_name : str
3172            the name of the search to remove
3173
3174        """
3175        self._searches().remove_episode(search_name)
3176        graph_path = os.path.join(self.project_path, "results", "searches", search_name)
3177        if os.path.exists(graph_path):
3178            shutil.rmtree(graph_path)
3179
3180    def remove_suggestion(self, suggestion_name: str) -> None:
3181        """Remove a suggestion record.
3182
3183        Parameters
3184        ----------
3185        suggestion_name : str
3186            the name of the suggestion to remove
3187
3188        """
3189        self._suggestions().remove_episode(suggestion_name)
3190        suggestion_path = os.path.join(
3191            self.project_path, "results", "suggestions", suggestion_name
3192        )
3193        if os.path.exists(suggestion_path):
3194            shutil.rmtree(suggestion_path)
3195
3196    def remove_prediction(self, prediction_name: str) -> None:
3197        """Remove a prediction record.
3198
3199        Parameters
3200        ----------
3201        prediction_name : str
3202            the name of the prediction to remove
3203
3204        """
3205        self._predictions().remove_episode(prediction_name)
3206        prediction_path = self.prediction_path(prediction_name)
3207        if os.path.exists(prediction_path):
3208            shutil.rmtree(prediction_path)
3209
3210    def check_prediction_exists(self, prediction_name: str) -> str | None:
3211        """Check if a prediction exists.
3212
3213        Parameters
3214        ----------
3215        prediction_name : str
3216            the name of the prediction to check
3217
3218        Returns
3219        -------
3220        str | None
3221            the path to the prediction if it exists, `None` otherwise
3222
3223        """
3224        prediction_path = self.prediction_path(prediction_name)
3225        if os.path.exists(prediction_path):
3226            return prediction_path
3227        return None
3228
3229    def remove_episode(self, episode_name: str) -> None:
3230        """Remove all model, logs and metafile records related to an episode.
3231
3232        Parameters
3233        ----------
3234        episode_name : str
3235            the name of the episode to remove
3236
3237        """
3238        runs = self._episodes().get_runs(episode_name)
3239        runs.append(episode_name)
3240        for run in runs:
3241            self._episodes().remove_episode(run)
3242            model_path = os.path.join(self.project_path, "results", "model", run)
3243            if os.path.exists(model_path):
3244                shutil.rmtree(model_path)
3245            log_path = os.path.join(self.project_path, "results", "logs", f"{run}.txt")
3246            if os.path.exists(log_path):
3247                os.remove(log_path)
3248
3249    @abstractmethod
3250    def _reformat_results(res: dict, classes: dict, exclusive=False):
3251        """Add classes to micro metrics in results from evaluation"""
3252        results = deepcopy(res)
3253        for key in results.keys():
3254            if isinstance(results[key], list):
3255                if exclusive and len(classes) == len(results[key]) + 1:
3256                    other_ind = list(classes.keys())[
3257                        list(classes.values()).index("other")
3258                    ]
3259                    classes = {
3260                        (i if i < other_ind else i - 1): c
3261                        for i, c in classes.items()
3262                        if i != other_ind
3263                    }
3264                assert len(results[key]) == len(
3265                    classes
3266                ), f"Results for {key} have {len(results[key])} values, but {len(classes)} classes were provided!"
3267                results[key] = {
3268                    classes[i]: float(v) for i, v in enumerate(results[key])
3269                }
3270        return results
3271
3272    def prune_unfinished(self, exceptions: List = None) -> List:
3273        """Remove all interrupted episodes.
3274
3275        Remove all episodes that either don't have a log file or have less epochs in the log file than in
3276        the training parameters or have a model folder but not a record. Note that it can remove episodes that are
3277        currently running!
3278
3279        Parameters
3280        ----------
3281        exceptions : list
3282            the episodes to keep even if they are interrupted
3283
3284        Returns
3285        -------
3286        pruned : list
3287            a list of the episode names that were pruned
3288
3289        """
3290        if exceptions is None:
3291            exceptions = []
3292        unfinished = self._episodes().unfinished_episodes()
3293        unfinished = [x for x in unfinished if x not in exceptions]
3294        model_folders = os.listdir(os.path.join(self.project_path, "results", "model"))
3295        unfinished += [
3296            x for x in model_folders if x not in self._episodes().list_episodes().index
3297        ]
3298        print(f"PRUNING {unfinished}")
3299        for episode_name in unfinished:
3300            self.remove_episode(episode_name)
3301        print(f"\n")
3302        return unfinished
3303
3304    def prediction_path(self, prediction_name: str) -> str:
3305        """Get the path where prediction files are saved.
3306
3307        Parameters
3308        ----------
3309        prediction_name : str
3310            name of the prediction
3311
3312        Returns
3313        -------
3314        prediction_path : str
3315            the file path
3316
3317        """
3318        return os.path.join(
3319            self.project_path, "results", "predictions", f"{prediction_name}"
3320        )
3321
3322    def suggestion_path(self, suggestion_name: str) -> str:
3323        """Get the path where suggestion files are saved.
3324
3325        Parameters
3326        ----------
3327        suggestion_name : str
3328            name of the prediction
3329
3330        Returns
3331        -------
3332        suggestion_path : str
3333            the file path
3334
3335        """
3336        return os.path.join(
3337            self.project_path, "results", "suggestions", f"{suggestion_name}"
3338        )
3339
3340    @classmethod
3341    def print_data_types(cls):
3342        """Print available data types."""
3343        print("DATA TYPES:")
3344        for key, value in cls.data_types().items():
3345            print(f"{key}:")
3346            print(value.__doc__)
3347
3348    @classmethod
3349    def print_annotation_types(cls):
3350        """Print available annotation types."""
3351        print("ANNOTATION TYPES:")
3352        for key, value in cls.annotation_types():
3353            print(f"{key}:")
3354            print(value.__doc__)
3355
3356    @staticmethod
3357    def data_types() -> List:
3358        """Get available data types.
3359
3360        Returns
3361        -------
3362        data_types : list
3363            available data types
3364
3365        """
3366        return options.input_stores
3367
3368    @staticmethod
3369    def annotation_types() -> List:
3370        """Get available annotation types.
3371
3372        Returns
3373        -------
3374        list
3375            available annotation types
3376
3377        """
3378        return options.annotation_stores
3379
3380    def _save_mask(self, file: Dict, mask_name: str):
3381        """Save a mask file.
3382
3383        Parameters
3384        ----------
3385        file : dict
3386            the mask file data to save
3387        mask_name : str
3388            the name of the mask file
3389
3390        """
3391        if not os.path.exists(self._mask_path()):
3392            os.mkdir(self._mask_path())
3393        with open(os.path.join(self._mask_path(), mask_name + ".pickle"), "wb") as f:
3394            pickle.dump(file, f)
3395
3396    def _load_mask(self, mask_name: str) -> Dict:
3397        """Load a mask file.
3398
3399        Parameters
3400        ----------
3401        mask_name : str
3402            the name of the mask file to load
3403
3404        Returns
3405        -------
3406        mask : dict
3407            the loaded mask data
3408
3409        """
3410        with open(os.path.join(self._mask_path(), mask_name + ".pickle"), "rb") as f:
3411            data = pickle.load(f)
3412        return data
3413
3414    def _thresholds(self) -> DecisionThresholds:
3415        """Get the decision thresholds meta object.
3416
3417        Returns
3418        -------
3419        thresholds : DecisionThresholds
3420            the decision thresholds meta object
3421
3422        """
3423        return DecisionThresholds(self._thresholds_path())
3424
3425    def _episodes(self) -> SavedRuns:
3426        """Get the episodes meta object.
3427
3428        Returns
3429        -------
3430        episodes : SavedRuns
3431            the episodes meta object
3432
3433        """
3434        try:
3435            return SavedRuns(self._episodes_path(), self.project_path)
3436        except:
3437            self.load_metadata_backup()
3438            return SavedRuns(self._episodes_path(), self.project_path)
3439
3440    def _suggestions(self) -> Suggestions:
3441        """Get the suggestions meta object.
3442
3443        Returns
3444        -------
3445        suggestions : Suggestions
3446            the suggestions meta object
3447
3448        """
3449        try:
3450            return Suggestions(self._suggestions_path(), self.project_path)
3451        except:
3452            self.load_metadata_backup()
3453            return Suggestions(self._suggestions_path(), self.project_path)
3454
3455    def _predictions(self) -> SavedRuns:
3456        """Get the predictions meta object.
3457
3458        Returns
3459        -------
3460        predictions : SavedRuns
3461            the predictions meta object
3462
3463        """
3464        try:
3465            return SavedRuns(self._predictions_path(), self.project_path)
3466        except:
3467            self.load_metadata_backup()
3468            return SavedRuns(self._predictions_path(), self.project_path)
3469
3470    def _saved_datasets(self) -> SavedStores:
3471        """Get the datasets meta object.
3472
3473        Returns
3474        -------
3475        datasets : SavedStores
3476            the datasets meta object
3477
3478        """
3479        try:
3480            return SavedStores(self._saved_datasets_path())
3481        except:
3482            self.load_metadata_backup()
3483            return SavedStores(self._saved_datasets_path())
3484
3485    def _prediction(self, name: str) -> Run:
3486        """Get a prediction meta object.
3487
3488        Parameters
3489        ----------
3490        name : str
3491            episode name
3492
3493        Returns
3494        -------
3495        prediction : Run
3496            the prediction meta object
3497
3498        """
3499        try:
3500            return Run(name, self.project_path, meta_path=self._predictions_path())
3501        except:
3502            self.load_metadata_backup()
3503            return Run(name, self.project_path, meta_path=self._predictions_path())
3504
3505    def _episode(self, name: str) -> Run:
3506        """Get an episode meta object.
3507
3508        Parameters
3509        ----------
3510        name : str
3511            episode name
3512
3513        Returns
3514        -------
3515        episode : Run
3516            the episode meta object
3517
3518        """
3519        try:
3520            return Run(name, self.project_path, meta_path=self._episodes_path())
3521        except:
3522            self.load_metadata_backup()
3523            return Run(name, self.project_path, meta_path=self._episodes_path())
3524
3525    def _searches(self) -> Searches:
3526        """Get the hyperparameter search meta object.
3527
3528        Returns
3529        -------
3530        searches : Searches
3531            the searches meta object
3532
3533        """
3534        try:
3535            return Searches(self._searches_path(), self.project_path)
3536        except:
3537            self.load_metadata_backup()
3538            return Searches(self._searches_path(), self.project_path)
3539
3540    def _update_configs(self) -> None:
3541        """Update the project config files with newly added files and parameters.
3542
3543        This method updates the project configuration with the data path and copies
3544        any new configuration files from the original package to the project.
3545
3546        """
3547        self.update_parameters({"data": {"data_path": self.data_path}})
3548        folders = ["augmentations", "features", "model"]
3549        original_path = os.path.join(
3550            os.path.dirname(os.path.dirname(__file__)), "config"
3551        )
3552        project_path = os.path.join(self.project_path, "config")
3553        filenames = [x for x in os.listdir(original_path) if x.endswith("yaml")]
3554        for folder in folders:
3555            filenames += [
3556                os.path.join(folder, x)
3557                for x in os.listdir(os.path.join(original_path, folder))
3558            ]
3559        filenames.append(os.path.join("data", f"{self.data_type}.yaml"))
3560        if self.annotation_type != "none":
3561            filenames.append(os.path.join("annotation", f"{self.annotation_type}.yaml"))
3562        for file in filenames:
3563            filepath_original = os.path.join(original_path, file)
3564            if file.startswith("data") or file.startswith("annotation"):
3565                file = os.path.basename(file)
3566            filepath_project = os.path.join(project_path, file)
3567            if not os.path.exists(filepath_project):
3568                shutil.copy(filepath_original, filepath_project)
3569            else:
3570                original_pars = self._open_yaml(filepath_original)
3571                project_pars = self._open_yaml(filepath_project)
3572                to_remove = []
3573                for key, value in project_pars.items():
3574                    if key not in original_pars:
3575                        if key not in ["data_type", "annotation_type"]:
3576                            to_remove.append(key)
3577                for key in to_remove:
3578                    project_pars.pop(key)
3579                to_remove = []
3580                for key, value in original_pars.items():
3581                    if key in project_pars:
3582                        to_remove.append(key)
3583                for key in to_remove:
3584                    original_pars.pop(key)
3585                project_pars = self._update(project_pars, original_pars)
3586                with open(filepath_project, "w", encoding="utf-8") as f:
3587                    YAML().dump(project_pars, f)
3588
3589    def _update_project(self) -> None:
3590        """Update project files with the current version."""
3591        version_file = self._version_path()
3592        ok = True
3593        if not os.path.exists(version_file):
3594            ok = False
3595        else:
3596            with open(version_file, encoding="utf-8") as f:
3597                project_version = f.read()
3598            if project_version < __version__:
3599                ok = False
3600            elif project_version > __version__:
3601                warnings.warn(
3602                    f"The project expects a higher dlc2action version ({project_version}), please update!"
3603                )
3604        if not ok:
3605            project_config_path = os.path.join(self.project_path, "config")
3606            config_path = os.path.join(
3607                os.path.dirname(os.path.dirname(__path__)), "config"
3608            )
3609            episodes = self._episodes()
3610            folders = ["annotation", "augmentations", "data", "features", "model"]
3611
3612            project_annotation_configs = os.listdir(
3613                os.path.join(project_config_path, "annotation")
3614            )
3615            annotation_configs = os.listdir(os.path.join(config_path, "annotation"))
3616            for ann_config in annotation_configs:
3617                if ann_config not in project_annotation_configs:
3618                    shutil.copytree(
3619                        os.path.join(config_path, "annotation", ann_config),
3620                        os.path.join(project_config_path, "annotation", ann_config),
3621                        dirs_exist_ok=True,
3622                    )
3623                else:
3624                    project_pars = self._open_yaml(
3625                        os.path.join(project_config_path, "annotation", ann_config)
3626                    )
3627                    pars = self._open_yaml(
3628                        os.path.join(config_path, "annotation", ann_config)
3629                    )
3630                    new_keys = set(pars.keys()) - set(project_pars.keys())
3631                    for key in new_keys:
3632                        project_pars[key] = pars[key]
3633                        c = self._get_comment(pars.ca.items.get(key))
3634                        project_pars.yaml_add_eol_comment(c, key=key)
3635                        episodes.update(
3636                            condition=f"general/annotation_type::={ann_config}",
3637                            update={f"data/{key}": pars[key]},
3638                        )
3639
3640    def _initialize_project(
3641        self,
3642        data_type: str,
3643        annotation_type: str = None,
3644        data_path: str = None,
3645        annotation_path: str = None,
3646        copy: bool = True,
3647    ) -> None:
3648        """Initialize a new project."""
3649        if data_type not in self.data_types():
3650            raise ValueError(
3651                f"The {data_type} data type is not available. "
3652                f"Please choose from {self.data_types()}"
3653            )
3654        if annotation_type not in self.annotation_types():
3655            raise ValueError(
3656                f"The {annotation_type} annotation type is not available. "
3657                f"Please choose from {self.annotation_types()}"
3658            )
3659        os.mkdir(self.project_path)
3660        folders = ["results", "saved_datasets", "meta", "config"]
3661        for f in folders:
3662            os.mkdir(os.path.join(self.project_path, f))
3663        results_subfolders = [
3664            "model",
3665            "logs",
3666            "predictions",
3667            "splits",
3668            "searches",
3669            "suggestions",
3670        ]
3671        for sf in results_subfolders:
3672            os.mkdir(os.path.join(self.project_path, "results", sf))
3673        if data_path is not None:
3674            if copy:
3675                os.mkdir(os.path.join(self.project_path, "data"))
3676                shutil.copytree(
3677                    data_path,
3678                    os.path.join(self.project_path, "data"),
3679                    dirs_exist_ok=True,
3680                )
3681                data_path = os.path.join(self.project_path, "data")
3682        if annotation_path is not None:
3683            if copy:
3684                os.mkdir(os.path.join(self.project_path, "annotation"))
3685                shutil.copytree(
3686                    annotation_path,
3687                    os.path.join(self.project_path, "annotation"),
3688                    dirs_exist_ok=True,
3689                )
3690                annotation_path = os.path.join(self.project_path, "annotation")
3691        self._generate_config(
3692            data_type,
3693            annotation_type,
3694            data_path=data_path,
3695            annotation_path=annotation_path,
3696        )
3697        self._generate_meta()
3698
3699    def _read_types(self) -> Tuple[str, str]:
3700        """Get data type and annotation type from existing project files."""
3701        config_path = os.path.join(self.project_path, "config", "general.yaml")
3702        with open(config_path, encoding="utf-8") as f:
3703            pars = YAML().load(f)
3704        data_type = pars["data_type"]
3705        annotation_type = pars["annotation_type"]
3706        return annotation_type, data_type
3707
3708    def _read_paths(self) -> Tuple[str, str]:
3709        """Get data type and annotation type from existing project files."""
3710        config_path = os.path.join(self.project_path, "config", "data.yaml")
3711        with open(config_path, encoding="utf-8") as f:
3712            pars = YAML().load(f)
3713        data_path = pars["data_path"]
3714        annotation_path = pars["annotation_path"]
3715        return annotation_path, data_path
3716
3717    def _generate_config(
3718        self, data_type: str, annotation_type: str, data_path: str, annotation_path: str
3719    ) -> None:
3720        """Initialize the config files."""
3721        default_path = os.path.join(
3722            os.path.dirname(os.path.dirname(__file__)), "config"
3723        )
3724        config_path = os.path.join(self.project_path, "config")
3725        files = ["losses", "metrics", "ssl", "training"]
3726        for f in files:
3727            shutil.copy(os.path.join(default_path, f"{f}.yaml"), config_path)
3728        shutil.copytree(
3729            os.path.join(default_path, "model"), os.path.join(config_path, "model")
3730        )
3731        shutil.copytree(
3732            os.path.join(default_path, "features"),
3733            os.path.join(config_path, "features"),
3734        )
3735        shutil.copytree(
3736            os.path.join(default_path, "augmentations"),
3737            os.path.join(config_path, "augmentations"),
3738        )
3739        yaml = YAML()
3740        data_param_path = os.path.join(default_path, "data", f"{data_type}.yaml")
3741        data_params = None
3742        if os.path.exists(data_param_path):
3743            with open(data_param_path, encoding="utf-8") as f:
3744                data_params = yaml.load(f)
3745        if data_params is None:
3746            data_params = {}
3747        if annotation_type is None:
3748            ann_params = {}
3749        else:
3750            ann_param_path = os.path.join(
3751                default_path, "annotation", f"{annotation_type}.yaml"
3752            )
3753            if os.path.exists(ann_param_path):
3754                ann_params = self._open_yaml(ann_param_path)
3755            elif annotation_type == "none":
3756                ann_params = {}
3757            else:
3758                raise ValueError(
3759                    f"The {annotation_type} data type is not available. "
3760                    f"Please choose from {BehaviorDataset.annotation_types()}"
3761                )
3762        if ann_params is None:
3763            ann_params = {}
3764        data_params = self._update(data_params, ann_params)
3765        data_params["data_path"] = data_path
3766        data_params["annotation_path"] = annotation_path
3767        with open(os.path.join(config_path, "data.yaml"), "w", encoding="utf-8") as f:
3768            yaml.dump(data_params, f)
3769        with open(os.path.join(default_path, "general.yaml"), encoding="utf-8") as f:
3770            general_params = yaml.load(f)
3771        general_params["data_type"] = data_type
3772        general_params["annotation_type"] = annotation_type
3773        with open(
3774            os.path.join(config_path, "general.yaml"), "w", encoding="utf-8"
3775        ) as f:
3776            yaml.dump(general_params, f)
3777
3778    def _generate_meta(self) -> None:
3779        """Initialize the meta files."""
3780        config_file = os.path.join(self.project_path, "config")
3781        meta_fields = ["time"]
3782        columns = [("meta", field) for field in meta_fields]
3783        episodes = pd.DataFrame(columns=pd.MultiIndex.from_tuples(columns))
3784        episodes.to_pickle(self._episodes_path())
3785        meta_fields = ["time", "objective"]
3786        result_fields = ["best_params", "best_value"]
3787        columns = [("meta", field) for field in meta_fields] + [
3788            ("results", field) for field in result_fields
3789        ]
3790        searches = pd.DataFrame(columns=pd.MultiIndex.from_tuples(columns))
3791        searches.to_pickle(self._searches_path())
3792        meta_fields = ["time"]
3793        columns = [("meta", field) for field in meta_fields]
3794        predictions = pd.DataFrame(columns=pd.MultiIndex.from_tuples(columns))
3795        predictions.to_pickle(self._predictions_path())
3796        suggestions = pd.DataFrame(columns=pd.MultiIndex.from_tuples(columns))
3797        suggestions.to_pickle(self._suggestions_path())
3798        with open(os.path.join(config_file, "data.yaml"), encoding="utf-8") as f:
3799            data_keys = list(YAML().load(f).keys())
3800        saved_data = pd.DataFrame(columns=data_keys)
3801        saved_data.to_pickle(self._saved_datasets_path())
3802        pd.DataFrame().to_pickle(self._thresholds_path())
3803        # with open(self._version_path()) as f:
3804        #     f.write(__version__)
3805
3806    def _open_yaml(self, path: str) -> CommentedMap:
3807        """Load a parameter dictionary from a .yaml file."""
3808        with open(path, encoding="utf-8") as f:
3809            data = YAML().load(f)
3810        if data is None:
3811            data = {}
3812        return data
3813
3814    def _compare(self, d: Dict, u: Dict, allow_diff: float = 1e-7):
3815        """Compare nested dictionaries with 'almost equal' condition."""
3816        ok = True
3817        if u.keys() != d.keys():
3818            ok = False
3819        else:
3820            for k, v in u.items():
3821                if isinstance(v, Mapping):
3822                    ok = self._compare(d[k], v, allow_diff=allow_diff)
3823                else:
3824                    if isinstance(v, float) or isinstance(d[k], float):
3825                        if not isinstance(d[k], float) and not isinstance(d[k], int):
3826                            ok = False
3827                        elif not isinstance(v, float) and not isinstance(v, int):
3828                            ok = False
3829                        elif np.abs(v - d[k]) > allow_diff:
3830                            ok = False
3831                    elif v != d[k]:
3832                        ok = False
3833        return ok
3834
3835    def _check_comment(self, comment_sequence: List) -> bool:
3836        """Check if a comment already exists in a ruamel.yaml comment sequence."""
3837        if comment_sequence is None:
3838            return False
3839        c = self._get_comment(comment_sequence)
3840        if c != "":
3841            return True
3842        else:
3843            return False
3844
3845    def _get_comment(self, comment_sequence: List, strip=True) -> str:
3846        """Get the comment string from a ruamel.yaml comment sequence."""
3847        if comment_sequence is None:
3848            return ""
3849        c = ""
3850        for cm in comment_sequence:
3851            if cm is not None:
3852                if isinstance(cm, Iterable):
3853                    for c in cm:
3854                        if c is not None:
3855                            c = c.value
3856                            break
3857                    break
3858                else:
3859                    c = cm.value
3860                    break
3861        if strip:
3862            c = c.strip()
3863        return c
3864
3865    def _update(self, d: Union[CommentedMap, Dict], u: Union[CommentedMap, Dict]):
3866        """Update a nested dictionary."""
3867        if "general" in u and "model_name" in u["general"] and "model" in d:
3868            model_name = u["general"]["model_name"]
3869            if d["general"]["model_name"] != model_name:
3870                d["model"] = self._open_yaml(
3871                    os.path.join(
3872                        self.project_path, "config", "model", f"{model_name}.yaml"
3873                    )
3874                )
3875        d_copied = deepcopy(d)
3876        for k, v in u.items():
3877            if (
3878                k in d_copied
3879                and isinstance(d_copied[k], list)
3880                and isinstance(v, Mapping)
3881                and all([isinstance(x, int) for x in v.keys()])
3882            ):
3883                for kk, vv in v.items():
3884                    d_copied[k][kk] = vv
3885            elif (
3886                isinstance(v, Mapping)
3887                and k in d_copied
3888                and isinstance(d_copied[k], Mapping)
3889            ):
3890                if d_copied[k] is None:
3891                    d_k = CommentedMap()
3892                else:
3893                    d_k = d_copied[k]
3894                d_copied[k] = self._update(d_k, v)
3895            else:
3896                d_copied[k] = v
3897                if isinstance(u, CommentedMap) and u.ca.items.get(k) is not None:
3898                    c = self._get_comment(u.ca.items.get(k), strip=False)
3899                    if isinstance(d_copied, CommentedMap) and not self._check_comment(
3900                        d_copied.ca.items.get(k)
3901                    ):
3902                        d_copied.yaml_add_eol_comment(c, key=k)
3903        return d_copied
3904
3905    def _update_with_search(
3906        self,
3907        d: Dict,
3908        search_name: str,
3909        load_parameters: list = None,
3910        round_to_binary: list = None,
3911    ):
3912        """Update a dictionary with best parameters from a hyperparameter search."""
3913        u, _ = self._searches().get_best_params(
3914            search_name, load_parameters, round_to_binary
3915        )
3916        return self._update(d, u)
3917
3918    def _read_parameters(self, catch_blanks=True) -> Dict:
3919        """Compose a parameter dictionary to create a task from the config files."""
3920        config_path = os.path.join(self.project_path, "config")
3921        keys = [
3922            "data",
3923            "general",
3924            "losses",
3925            "metrics",
3926            "ssl",
3927            "training",
3928        ]
3929        parameters = {}
3930        for key in keys:
3931            parameters[key] = self._open_yaml(os.path.join(config_path, f"{key}.yaml"))
3932        features = parameters["general"]["feature_extraction"]
3933        parameters["features"] = self._open_yaml(
3934            os.path.join(config_path, "features", f"{features}.yaml")
3935        )
3936        transformer = options.extractor_to_transformer[features]
3937        parameters["augmentations"] = self._open_yaml(
3938            os.path.join(config_path, "augmentations", f"{transformer}.yaml")
3939        )
3940        model = parameters["general"]["model_name"]
3941        parameters["model"] = self._open_yaml(
3942            os.path.join(config_path, "model", f"{model}.yaml")
3943        )
3944        # input = parameters["general"]["input"]
3945        # parameters["model"] = self._open_yaml(
3946        #     os.path.join(config_path, "model", f"{model}.yaml")
3947        # )
3948        if catch_blanks:
3949            blanks = self._get_blanks()
3950            if len(blanks) > 0:
3951                self.list_blanks()
3952                raise ValueError(
3953                    f"Please fill in all the blanks before running experiments"
3954                )
3955        return parameters
3956
3957    def set_main_parameters(self, model_name: str = None, metric_names: List = None):
3958        """Select the model and the metrics.
3959
3960        Parameters
3961        ----------
3962        model_name : str, optional
3963            model name; run `project.help("model") to find out more
3964        metric_names : list, optional
3965            a list of metric function names; run `project.help("metrics") to find out more
3966
3967        """
3968        pars = {"general": {}}
3969        if model_name is not None:
3970            assert model_name in options.models
3971            pars["general"]["model_name"] = model_name
3972        if metric_names is not None:
3973            for metric in metric_names:
3974                assert metric in options.metrics
3975            pars["general"]["metric_functions"] = metric_names
3976        self.update_parameters(pars)
3977
3978    def help(self, keyword: str = None):
3979        """Get information on available options.
3980
3981        Parameters
3982        ----------
3983        keyword : str, optional
3984            the keyword for options (run without arguments to see which keywords are available)
3985
3986        """
3987        if keyword is None:
3988            print("AVAILABLE HELP FUNCTIONS:")
3989            print("- Try running `project.help(keyword)` with the following keywords:")
3990            print("    - model: to get more information on available models,")
3991            print(
3992                "    - features: to get more information on available feature extraction modes,"
3993            )
3994            print(
3995                "    - partition_method: to get more information on available train/test/val partitioning methods,"
3996            )
3997            print("    - metrics: to see a list of available metric functions.")
3998            print("    - data: to see help for expected data structure")
3999            print(
4000                "- To start working with this project, first run `project.list_blanks()` to check which parameters need to be filled in."
4001            )
4002            print(
4003                "- After a model and metrics are set, run `project.list_basic_parameters()` to see a list of the most important parameters that you might want to modify"
4004            )
4005            print(
4006                f"- If you want to dig deeper, get the full dictionary with project._read_parameters() (it is a `ruamel.yaml.comments.CommentedMap` instance)."
4007            )
4008        elif keyword == "model":
4009            print("MODELS:")
4010            for key, model in options.models.items():
4011                print(f"{key}:")
4012                print(model.__doc__)
4013        elif keyword == "features":
4014            print("FEATURE EXTRACTORS:")
4015            for key, extractor in options.feature_extractors.items():
4016                print(f"{key}:")
4017                print(extractor.__doc__)
4018        elif keyword == "partition_method":
4019            print("PARTITION METHODS:")
4020            print(
4021                BehaviorDataset.partition_train_test_val.__doc__.split(
4022                    "The partitioning method:"
4023                )[1].split("val_frac :")[0]
4024            )
4025        elif keyword == "metrics":
4026            print("METRICS:")
4027            for key, metric in options.metrics.items():
4028                print(f"{key}:")
4029                print(metric.__doc__)
4030        elif keyword == "data":
4031            print("DATA:")
4032            print(f"Video data: {self.data_type}")
4033            print(options.input_stores[self.data_type].__doc__)
4034            print(f"Annotation data: {self.annotation_type}")
4035            print(options.annotation_stores[self.annotation_type].__doc__)
4036            print(
4037                "Annotation path and data path don't have to be separate, you can keep everything in one folder."
4038            )
4039        else:
4040            raise ValueError(f"The {keyword} keyword is not recognized")
4041        print("\n")
4042
4043    def _process_value(self, value):
4044        """Process a configuration value for display.
4045
4046        Parameters
4047        ----------
4048        value : any
4049            the value to process
4050
4051        Returns
4052        -------
4053        processed_value : any
4054            the processed value
4055
4056        """
4057        if isinstance(value, str):
4058            value = f'"{value}"'
4059        elif isinstance(value, CommentedSet):
4060            value = {x for x in value}
4061        return value
4062
4063    def _get_blanks(self):
4064        """Get a list of blank (unset) parameters in the configuration.
4065
4066        Returns
4067        -------
4068        caught : list
4069            a list of parameter keys that have blank values
4070
4071        """
4072        caught = []
4073        parameters = self._read_parameters(catch_blanks=False)
4074        for big_key, big_value in parameters.items():
4075            for key, value in big_value.items():
4076                if value == "???":
4077                    caught.append(
4078                        (big_key, key, self._get_comment(big_value.ca.items.get(key)))
4079                    )
4080        return caught
4081
4082    def list_blanks(self, blanks=None):
4083        """List parameters that need to be filled in.
4084
4085        Parameters
4086        ----------
4087        blanks : list, optional
4088            a list of the parameters to list, if already known
4089
4090        """
4091        if blanks is None:
4092            blanks = self._get_blanks()
4093        if len(blanks) > 0:
4094            to_update = defaultdict(lambda: [])
4095            for b, k, c in blanks:
4096                to_update[b].append((k, c))
4097            print("Before running experiments, please update all the blanks.")
4098            print("To do that, you can run this.")
4099            print("--------------------------------------------------------")
4100            print(f"project.update_parameters(")
4101            print(f"    {{")
4102            for big_key, keys in to_update.items():
4103                print(f'        "{big_key}": {{')
4104                for key, comment in keys:
4105                    print(f'            "{key}": ..., {comment}')
4106                print(f"        }}")
4107            print(f"    }}")
4108            print(")")
4109            print("--------------------------------------------------------")
4110            print("Replace ... with relevant values.")
4111        else:
4112            print("There is no blanks left!")
4113
4114    def list_basic_parameters(
4115        self,
4116    ):
4117        """Get a list of most relevant parameters and code to modify them."""
4118        parameters = self._read_parameters()
4119        print("BASIC PARAMETERS:")
4120        model_name = parameters["general"]["model_name"]
4121        metric_names = parameters["general"]["metric_functions"]
4122        loss_name = parameters["general"]["loss_function"]
4123        feature_extraction = parameters["general"]["feature_extraction"]
4124        print("Here is a list of current parameters.")
4125        print(
4126            "You can copy this code, change the parameters you want to set and run it to update the project config."
4127        )
4128        print("--------------------------------------------------------")
4129        print("project.update_parameters(")
4130        print("    {")
4131        for group in ["general", "data", "training"]:
4132            print(f'        "{group}": {{')
4133            for key in options.basic_parameters[group]:
4134                if key in parameters[group]:
4135                    print(
4136                        f'            "{key}": {self._process_value(parameters[group][key])}, {self._get_comment(parameters[group].ca.items.get(key))}'
4137                    )
4138            print("        },")
4139        print('        "losses": {')
4140        print(f'            "{loss_name}": {{')
4141        for key in options.basic_parameters["losses"][loss_name]:
4142            if key in parameters["losses"][loss_name]:
4143                print(
4144                    f'                "{key}": {self._process_value(parameters["losses"][loss_name][key])}, {self._get_comment(parameters["losses"][loss_name].ca.items.get(key))}'
4145                )
4146        print("            },")
4147        print("        },")
4148        print('        "metrics": {')
4149        for metric in metric_names:
4150            print(f'            "{metric}": {{')
4151            for key in parameters["metrics"][metric]:
4152                print(
4153                    f'                "{key}": {self._process_value(parameters["metrics"][metric][key])}, {self._get_comment(parameters["metrics"][metric].ca.items.get(key))}'
4154                )
4155            print("            },")
4156        print("        },")
4157        print('        "model": {')
4158        for key in options.basic_parameters["model"][model_name]:
4159            if key in parameters["model"]:
4160                print(
4161                    f'            "{key}": {self._process_value(parameters["model"][key])}, {self._get_comment(parameters["model"].ca.items.get(key))}'
4162                )
4163
4164        print("        },")
4165        print('        "features": {')
4166        for key in options.basic_parameters["features"][feature_extraction]:
4167            if key in parameters["features"]:
4168                print(
4169                    f'            "{key}": {self._process_value(parameters["features"][key])}, {self._get_comment(parameters["features"].ca.items.get(key))}'
4170                )
4171
4172        print("        },")
4173        print('        "augmentations": {')
4174        for key in options.basic_parameters["augmentations"][feature_extraction]:
4175            if key in parameters["augmentations"]:
4176                print(
4177                    f'            "{key}": {self._process_value(parameters["augmentations"][key])}, {self._get_comment(parameters["augmentations"].ca.items.get(key))}'
4178                )
4179        print("        },")
4180        print("    },")
4181        print(")")
4182        print("--------------------------------------------------------")
4183        print("\n")
4184
4185    def _create_record(
4186        self,
4187        episode_name: str,
4188        behaviors_dict: Dict,
4189        load_episode: str = None,
4190        parameters_update: Dict = None,
4191        task: TaskDispatcher = None,
4192        load_epoch: int = None,
4193        load_search: str = None,
4194        load_parameters: list = None,
4195        round_to_binary: list = None,
4196        load_strict: bool = True,
4197        n_seeds: int = 1,
4198    ) -> TaskDispatcher:
4199        """Create a meta data episode record."""
4200        if episode_name in self._episodes().data.index:
4201            return
4202        if type(n_seeds) is not int or n_seeds < 1:
4203            raise ValueError(
4204                f"The n_seeds parameter has to be an integer larger than 0; got {n_seeds}"
4205            )
4206        if parameters_update is None:
4207            parameters_update = {}
4208        parameters = self._read_parameters()
4209        parameters = self._update(parameters, parameters_update)
4210        if load_search is not None:
4211            parameters = self._update_with_search(
4212                parameters, load_search, load_parameters, round_to_binary
4213            )
4214        parameters = self._fill(
4215            parameters,
4216            episode_name,
4217            load_episode,
4218            load_epoch=load_epoch,
4219            only_load_model=True,
4220            load_strict=load_strict,
4221            continuing=True,
4222        )
4223        self._save_episode(episode_name, parameters, behaviors_dict)
4224        return task
4225
4226    def _save_thresholds(
4227        self,
4228        episode_names: List,
4229        metric_name: str,
4230        parameters: Dict,
4231        thresholds: List,
4232        load_epochs: List,
4233    ):
4234        """Save optimal decision thresholds in the meta records."""
4235        metric_parameters = parameters["metrics"][metric_name]
4236        self._thresholds().save_thresholds(
4237            episode_names, load_epochs, metric_name, metric_parameters, thresholds
4238        )
4239
4240    def _save_episode(
4241        self,
4242        episode_name: str,
4243        parameters: Dict,
4244        behaviors_dict: Dict,
4245        suppress_validation: bool = False,
4246        training_time: str = None,
4247        norm_stats: Dict = None,
4248    ) -> None:
4249        """Save an episode in the meta files."""
4250        try:
4251            split_info = self._split_info_from_filename(
4252                parameters["training"]["split_path"]
4253            )
4254            parameters["training"]["partition_method"] = split_info["partition_method"]
4255        except:
4256            pass
4257        if norm_stats is not None:
4258            norm_stats = dict(norm_stats)
4259        parameters["training"]["stats"] = norm_stats
4260        self._episodes().save_episode(
4261            episode_name,
4262            parameters,
4263            behaviors_dict,
4264            suppress_validation=suppress_validation,
4265            training_time=training_time,
4266        )
4267
4268    def _save_suggestions(
4269        self, suggestions_name: str, parameters: Dict, meta_parameters: Dict
4270    ) -> None:
4271        """Save a suggestion in the meta files."""
4272        self._suggestions().save_suggestion(
4273            suggestions_name, parameters, meta_parameters
4274        )
4275
4276    def _update_episode_results(
4277        self,
4278        episode_name: str,
4279        logs: Tuple,
4280        training_time: str = None,
4281    ) -> None:
4282        """Save the results of a run in the meta files."""
4283        self._episodes().update_episode_results(episode_name, logs, training_time)
4284
4285    def _save_prediction(
4286        self,
4287        prediction_name: str,
4288        predicted: Dict[str, Dict],
4289        parameters: Dict,
4290        task: TaskDispatcher,
4291        mode: str = "test",
4292        embedding: bool = False,
4293        inference_time: str = None,
4294        behavior_dict: List[Dict[str, Any]] = None,
4295    ) -> None:
4296        """Save a prediction in the meta files."""
4297
4298        folder = self.prediction_path(prediction_name)
4299        os.mkdir(folder)
4300        for video_id, prediction in predicted.items():
4301            with open(
4302                os.path.join(
4303                    folder, video_id + f"_{prediction_name}_prediction.pickle"
4304                ),
4305                "wb",
4306            ) as f:
4307                prediction["min_frames"], prediction["max_frames"] = task.dataset(
4308                    mode
4309                ).get_min_max_frames(video_id)
4310                prediction["classes"] = behavior_dict
4311                pickle.dump(prediction, f)
4312
4313        parameters = self._update(
4314            parameters,
4315            {"meta": {"embedding": embedding, "inference_time": inference_time}},
4316        )
4317        self._predictions().save_episode(
4318            prediction_name, parameters, task.behaviors_dict()
4319        )
4320
4321    def _save_search(
4322        self,
4323        search_name: str,
4324        parameters: Dict,
4325        n_trials: int,
4326        best_params: Dict,
4327        best_value: float,
4328        metric: str,
4329        search_space: Dict,
4330    ) -> None:
4331        """Save a hyperparameter search in the meta files."""
4332        self._searches().save_search(
4333            search_name,
4334            parameters,
4335            n_trials,
4336            best_params,
4337            best_value,
4338            metric,
4339            search_space,
4340        )
4341
4342    def _save_stores(self, parameters: Dict) -> None:
4343        """Save a pickled dataset in the meta files."""
4344        name = os.path.basename(parameters["data"]["feature_save_path"])
4345        self._saved_datasets().save_store(name, self._get_data_pars(parameters))
4346        self.create_metadata_backup()
4347
4348    def _remove_stores(self, parameters: Dict, remove_active: bool = False) -> None:
4349        """Remove the pre-computed features folder."""
4350        name = os.path.basename(parameters["data"]["feature_save_path"])
4351        if remove_active or name not in self._episodes().get_active_datasets():
4352            self.remove_saved_features([name])
4353
4354    def _check_episode_validity(
4355        self, episode_name: str, allow_doublecolon: bool = False, force: bool = False
4356    ) -> None:
4357        """Check whether the episode name is valid."""
4358        if episode_name.startswith("_"):
4359            raise ValueError(
4360                "Names starting with an underscore are reserved by dlc2action and cannot be used!"
4361            )
4362        elif "." in episode_name:
4363            raise ValueError("Names containing '.' cannot be used!")
4364        if not allow_doublecolon and "#" in episode_name:
4365            raise ValueError(
4366                "Names containing '#' are reserved by dlc2action and cannot be used!"
4367            )
4368        if "::" in episode_name:
4369            raise ValueError(
4370                "Names containing '::' are reserved by dlc2action and cannot be used!"
4371            )
4372        if force:
4373            self.remove_episode(episode_name)
4374        elif not self._episodes().check_name_validity(episode_name):
4375            raise ValueError(
4376                f"The {episode_name} name is already taken! Set force=True to overwrite."
4377            )
4378
4379    def _check_search_validity(self, search_name: str, force: bool = False) -> None:
4380        """Check whether the search name is valid."""
4381        if search_name.startswith("_"):
4382            raise ValueError(
4383                "Names starting with an underscore are reserved by dlc2action and cannot be used!"
4384            )
4385        elif "." in search_name:
4386            raise ValueError("Names containing '.' cannot be used!")
4387        if force:
4388            self.remove_search(search_name)
4389        elif not self._searches().check_name_validity(search_name):
4390            raise ValueError(f"The {search_name} name is already taken!")
4391
4392    def _check_prediction_validity(
4393        self, prediction_name: str, force: bool = False
4394    ) -> None:
4395        """Check whether the prediction name is valid."""
4396        if prediction_name.startswith("_"):
4397            raise ValueError(
4398                "Names starting with an underscore are reserved by dlc2action and cannot be used!"
4399            )
4400        elif "." in prediction_name:
4401            raise ValueError("Names containing '.' cannot be used!")
4402        if force:
4403            self.remove_prediction(prediction_name)
4404        elif not self._predictions().check_name_validity(prediction_name):
4405            raise ValueError(f"The {prediction_name} name is already taken!")
4406
4407    def _check_suggestions_validity(
4408        self, suggestions_name: str, force: bool = False
4409    ) -> None:
4410        """Check whether the suggestions name is valid."""
4411        if suggestions_name.startswith("_"):
4412            raise ValueError(
4413                "Names starting with an underscore are reserved by dlc2action and cannot be used!"
4414            )
4415        elif "." in suggestions_name:
4416            raise ValueError("Names containing '.' cannot be used!")
4417        if force:
4418            self.remove_suggestion(suggestions_name)
4419        elif not self._suggestions().check_name_validity(suggestions_name):
4420            raise ValueError(f"The {suggestions_name} name is already taken!")
4421
4422    def _training_time(self, episode_name: str) -> int:
4423        """Get the training time of an episode in seconds."""
4424        return self._episode(episode_name).training_time()
4425
4426    def _mask_path(self) -> str:
4427        """Get the path to the masks folder.
4428
4429        Returns
4430        -------
4431        path : str
4432            the path to the masks folder
4433
4434        """
4435        return os.path.join(self.project_path, "results", "masks")
4436
4437    def _thresholds_path(self) -> str:
4438        """Get the path to the thresholds meta file.
4439
4440        Returns
4441        -------
4442        path : str
4443            the path to the thresholds meta file
4444
4445        """
4446        return os.path.join(self.project_path, "meta", "thresholds.pickle")
4447
4448    def _episodes_path(self) -> str:
4449        """Get the path to the episodes meta file.
4450
4451        Returns
4452        -------
4453        path : str
4454            the path to the episodes meta file
4455
4456        """
4457        return os.path.join(self.project_path, "meta", "episodes.pickle")
4458
4459    def _suggestions_path(self) -> str:
4460        """Get the path to the suggestions meta file.
4461
4462        Returns
4463        -------
4464        path : str
4465            the path to the suggestions meta file
4466
4467        """
4468        return os.path.join(self.project_path, "meta", "suggestions.pickle")
4469
4470    def _saved_datasets_path(self) -> str:
4471        """Get the path to the datasets meta file.
4472
4473        Returns
4474        -------
4475        path : str
4476            the path to the datasets meta file
4477
4478        """
4479        return os.path.join(self.project_path, "meta", "saved_datasets.pickle")
4480
4481    def _predictions_path(self) -> str:
4482        """Get the path to the predictions meta file.
4483
4484        Returns
4485        -------
4486        path : str
4487            the path to the predictions meta file
4488
4489        """
4490        return os.path.join(self.project_path, "meta", "predictions.pickle")
4491
4492    def _dataset_store_path(self, name: str) -> str:
4493        """Get the path to a specific pickled dataset.
4494
4495        Parameters
4496        ----------
4497        name : str
4498            the name of the dataset
4499
4500        Returns
4501        -------
4502        path : str
4503            the path to the dataset file
4504
4505        """
4506        return os.path.join(self.project_path, "saved_datasets", f"{name}.pickle")
4507
4508    def _al_points_path(self, suggestions_name: str) -> str:
4509        """Get the path to an active learning intervals file.
4510
4511        Parameters
4512        ----------
4513        suggestions_name : str
4514            the name of the suggestions
4515
4516        Returns
4517        -------
4518        path : str
4519            the path to the active learning points file
4520
4521        """
4522        path = os.path.join(
4523            self.project_path,
4524            "results",
4525            "suggestions",
4526            suggestions_name,
4527            f"{suggestions_name}_al_points.pickle",
4528        )
4529        return path
4530
4531    def _suggestion_path(self, v_id: str, suggestions_name: str) -> str:
4532        """Get the path to a suggestion file.
4533
4534        Parameters
4535        ----------
4536        v_id : str
4537            the video ID
4538        suggestions_name : str
4539            the name of the suggestions
4540
4541        Returns
4542        -------
4543        path : str
4544            the path to the suggestion file
4545
4546        """
4547        path = os.path.join(
4548            self.project_path,
4549            "results",
4550            "suggestions",
4551            suggestions_name,
4552            f"{v_id}_suggestion.pickle",
4553        )
4554        return path
4555
4556    def _searches_path(self) -> str:
4557        """Get the path to the hyperparameter search meta file.
4558
4559        Returns
4560        -------
4561        path : str
4562            the path to the searches meta file
4563
4564        """
4565        return os.path.join(self.project_path, "meta", "searches.pickle")
4566
4567    def _search_path(self, name: str) -> str:
4568        """Get the default path to the graph folder for a specific hyperparameter search.
4569
4570        Parameters
4571        ----------
4572        name : str
4573            the name of the search
4574
4575        Returns
4576        -------
4577        path : str
4578            the path to the search folder
4579
4580        """
4581        return os.path.join(self.project_path, "results", "searches", name)
4582
4583    def _version_path(self) -> str:
4584        """Get the path to the version file.
4585
4586        Returns
4587        -------
4588        path : str
4589            the path to the version file
4590
4591        """
4592        return os.path.join(self.project_path, "meta", "version.txt")
4593
4594    def _default_split_file(self, split_info: Dict) -> Optional[str]:
4595        """Generate a path to a split file from split parameters.
4596
4597        Parameters
4598        ----------
4599        split_info : dict
4600            the split parameters dictionary
4601
4602        Returns
4603        -------
4604        split_file_path : str or None
4605            the path to the split file, or None if not applicable
4606
4607        """
4608        if split_info["partition_method"].startswith("time"):
4609            return None
4610        val_frac = split_info["val_frac"]
4611        test_frac = split_info["test_frac"]
4612        split_name = f'{split_info["partition_method"]}_val{val_frac * 100}%_test{test_frac * 100}%_len{split_info["len_segment"]}_overlap{split_info["overlap"]}'
4613        if not split_info["only_load_annotated"]:
4614            split_name += "_all"
4615        split_name += ".txt"
4616        return os.path.join(self.project_path, "results", "splits", split_name)
4617
4618    def _split_info_from_filename(self, split_name: str) -> Dict:
4619        """Get split parameters from default path to a split file.
4620
4621        Parameters
4622        ----------
4623        split_name : str
4624            the name/path of the split file
4625
4626        Returns
4627        -------
4628        split_info : dict
4629            the split parameters dictionary
4630
4631        """
4632        if split_name is None:
4633            return {}
4634        try:
4635            name = os.path.basename(split_name)[:-4]
4636            split = name.split("_")
4637            if len(split) == 6:
4638                only_load_annotated = False
4639            else:
4640                only_load_annotated = True
4641            len_segment = int(split[3][3:])
4642            overlap = float(split[4][7:])
4643            if overlap > 1:
4644                overlap = int(overlap)
4645            method, val, test = split[:3]
4646            val = float(val[3:-1]) / 100
4647            test = float(test[4:-1]) / 100
4648            return {
4649                "partition_method": method,
4650                "val_frac": val,
4651                "test_frac": test,
4652                "only_load_annotated": only_load_annotated,
4653                "len_segment": len_segment,
4654                "overlap": overlap,
4655            }
4656        except:
4657            return {"partition_method": "file"}
4658
4659    def _fill(
4660        self,
4661        parameters: Dict,
4662        episode_name: str,
4663        load_experiment: str = None,
4664        load_epoch: int = None,
4665        load_strict: bool = True,
4666        only_load_model: bool = False,
4667        continuing: bool = False,
4668        enforce_split_parameters: bool = False,
4669    ) -> Dict:
4670        """Update the parameters from the config files with project specific information.
4671
4672        Fill in the constant file path parameters and generate a unique log file and a model folder.
4673        Fill in the split file if the same split has been run before in the project and change partition method to
4674        from_file.
4675        Fill in saved data path if a dataset with the same data parameters already exists in the project.
4676        If load_experiment is not None, fill in the checkpoint path as well.
4677        The only_load_model training parameter is defined by the corresponding argument.
4678        If continuing is True, new files are not created and all information is loaded from load_experiment.
4679        If prediction is True, log and model files are not created.
4680        The enforce_split_parameters parameter is used to resolve conflicts
4681        between split file path and split parameters when they arise.
4682
4683        Parameters
4684        ----------
4685        parameters : dict
4686            the parameters dictionary to update
4687        episode_name : str
4688            the name of the episode
4689        load_experiment : str, optional
4690            the name of the experiment to load from
4691        load_epoch : int, optional
4692            the epoch to load (by default the last one)
4693        load_strict : bool, default True
4694            if `True`, strict loading is enforced
4695        only_load_model : bool, default False
4696            if `True`, only the model is loaded
4697        continuing : bool, default False
4698            if `True`, continues from existing files
4699        enforce_split_parameters : bool, default False
4700            if `True`, split parameters are enforced
4701
4702        Returns
4703        -------
4704        parameters : dict
4705            the updated parameters dictionary
4706
4707        """
4708        pars = deepcopy(parameters)
4709        if episode_name == "_":
4710            self.remove_episode("_")
4711        log = os.path.join(self.project_path, "results", "logs", f"{episode_name}.txt")
4712        model_save_path = os.path.join(
4713            self.project_path, "results", "model", episode_name
4714        )
4715        if not continuing and (os.path.exists(log) or os.path.exists(model_save_path)):
4716            raise ValueError(
4717                f"The {episode_name} episode name is already in use! Set force=True to overwrite."
4718            )
4719        keys = ["val_frac", "test_frac", "partition_method"]
4720        if "len_segment" not in pars["general"] and "len_segment" in pars["data"]:
4721            pars["general"]["len_segment"] = pars["data"]["len_segment"]
4722        if "overlap" not in pars["general"] and "overlap" in pars["data"]:
4723            pars["general"]["overlap"] = pars["data"]["overlap"]
4724        if "len_segment" in pars["data"]:
4725            pars["data"].pop("len_segment")
4726        if "overlap" in pars["data"]:
4727            pars["data"].pop("overlap")
4728        split_info = {k: pars["training"][k] for k in keys}
4729        split_info["only_load_annotated"] = pars["general"]["only_load_annotated"]
4730        split_info["len_segment"] = pars["general"]["len_segment"]
4731        split_info["overlap"] = pars["general"]["overlap"]
4732        pars["training"]["log_file"] = log
4733        if not os.path.exists(model_save_path):
4734            os.mkdir(model_save_path)
4735        pars["training"]["model_save_path"] = model_save_path
4736        if load_experiment is not None:
4737            if load_experiment not in self._episodes().data.index:
4738                raise ValueError(f"The {load_experiment} episode does not exist!")
4739            old_episode = self._episode(load_experiment)
4740            old_file = old_episode.split_file()
4741            old_info = self._split_info_from_filename(old_file)
4742            if len(old_info) == 0:
4743                old_info = old_episode.split_info()
4744            if enforce_split_parameters:
4745                if split_info["partition_method"] != "file":
4746                    pars["training"]["split_path"] = self._default_split_file(
4747                        split_info
4748                    )
4749            else:
4750                equal = True
4751                if old_info["partition_method"] != split_info["partition_method"]:
4752                    equal = False
4753                if old_info["partition_method"] != "file":
4754                    if (
4755                        old_info["val_frac"] != split_info["val_frac"]
4756                        or old_info["test_frac"] != split_info["test_frac"]
4757                    ):
4758                        equal = False
4759                if not continuing and not equal:
4760                    warnings.warn(
4761                        f"The partitioning parameters in the loaded experiment ({old_info}) "
4762                        f"are not equal to the current partitioning parameters ({split_info}). "
4763                        f"The current parameters are replaced."
4764                    )
4765                pars["training"]["split_path"] = old_file
4766                for k, v in old_info.items():
4767                    pars["training"][k] = v
4768            pars["training"]["checkpoint_path"] = old_episode.model_file(load_epoch)
4769            pars["training"]["load_strict"] = load_strict
4770        else:
4771            pars["training"]["checkpoint_path"] = None
4772            if pars["training"]["partition_method"] == "file":
4773                if (
4774                    "split_path" not in pars["training"]
4775                    or pars["training"]["split_path"] is None
4776                ):
4777                    raise ValueError(
4778                        "The partition_method parameter is set to file but the "
4779                        "split_path parameter is not set!"
4780                    )
4781                elif not os.path.exists(pars["training"]["split_path"]):
4782                    raise ValueError(
4783                        f'The {pars["training"]["split_path"]} split file does not exist'
4784                    )
4785            else:
4786                pars["training"]["split_path"] = self._default_split_file(split_info)
4787        pars["training"]["only_load_model"] = only_load_model
4788        pars["data"]["saved_data_path"] = None
4789        pars["data"]["feature_save_path"] = None
4790        pars_data_copy = self._get_data_pars(pars)
4791        saved_data_name = self._saved_datasets().find_name(pars_data_copy)
4792        if saved_data_name is not None:
4793            pars["data"]["saved_data_path"] = self._dataset_store_path(saved_data_name)
4794            pars["data"]["feature_save_path"] = self._dataset_store_path(
4795                saved_data_name
4796            ).split(".")[0]
4797        else:
4798            dataset_path = self._dataset_store_path(episode_name)
4799            if os.path.exists(dataset_path):
4800                name, ext = dataset_path.split(".")
4801                i = 0
4802                while os.path.exists(f"{name}_{i}.{ext}"):
4803                    i += 1
4804                dataset_path = f"{name}_{i}.{ext}"
4805            pars["data"]["saved_data_path"] = dataset_path
4806            pars["data"]["feature_save_path"] = dataset_path.split(".")[0]
4807        split_split = pars["training"]["partition_method"].split(":")
4808        random = True
4809        for partition_method in options.partition_methods["fixed"]:
4810            method_split = partition_method.split(":")
4811            if len(split_split) != len(method_split):
4812                continue
4813            equal = True
4814            for x, y in zip(split_split, method_split):
4815                if y.startswith("{"):
4816                    continue
4817                if x != y:
4818                    equal = False
4819                    break
4820            if equal:
4821                random = False
4822                break
4823        if random and os.path.exists(pars["training"]["split_path"]):
4824            pars["training"]["partition_method"] = "file"
4825        pars["general"]["save_dataset"] = True
4826        # Check len_segment for c2f models
4827        if pars["general"]["model_name"].startswith("c2f"):
4828            if int(pars["general"]["len_segment"]) < 512:
4829                raise ValueError(
4830                    "The segment length should be higher than 512 when using one of the C2F models"
4831                )
4832        return pars
4833
4834    def _get_data_pars(self, pars: Dict) -> Dict:
4835        """Get a complete description of the data from a general parameters dictionary.
4836
4837        Parameters
4838        ----------
4839        pars : dict
4840            the general parameters dictionary
4841
4842        Returns
4843        -------
4844        pars_data : dict
4845            the complete data parameters dictionary
4846
4847        """
4848        pars_data_copy = deepcopy(pars["data"])
4849        for par in [
4850            "only_load_annotated",
4851            "exclusive",
4852            "feature_extraction",
4853            "ignored_clips",
4854            "len_segment",
4855            "overlap",
4856        ]:
4857            pars_data_copy[par] = pars["general"].get(par, None)
4858        pars_data_copy.update(pars["features"])
4859        return pars_data_copy
4860
    def _make_al_points_from_suggestions(
        self,
        suggestions_name: str,
        task: TaskDispatcher,
        predicted_classes: Dict,
        background_threshold: Optional[float],
        visibility_min_score: float,
        visibility_min_frac: float,
        num_behaviors: int,
    ) -> Dict:
        """Generate active learning intervals from saved suggestion files.

        For every file in the suggestions folder, the suggested intervals of each clip
        (and, if `background_threshold` is set, the background intervals found in the
        predictions) are merged into continuous runs.

        Parameters
        ----------
        suggestions_name : str
            the name of the suggestions
        task : TaskDispatcher
            the task dispatcher whose "train" dataset is used to find the background intervals
        predicted_classes : dict
            the predictions that are passed to `find_valleys`
        background_threshold : float, optional
            the threshold used to find background intervals; if `None`, no background is generated
        visibility_min_score : float
            passed to `find_valleys`
        visibility_min_frac : float
            passed to `find_valleys`
        num_behaviors : int
            the number of behaviors to generate background for

        Returns
        -------
        al_points : dict
            a dictionary with the video IDs as keys and lists of
            `(first_index, last_index, clip_id)` tuples as values (the last index is inclusive)

        """
        valleys = []
        if background_threshold is not None:
            # One low-probability (background) search per behavior index.
            for i in range(num_behaviors):
                print(f"generating background for behavior {i}...")
                valleys.append(
                    task.dataset("train").find_valleys(
                        predicted_classes,
                        threshold=background_threshold,
                        visibility_min_score=visibility_min_score,
                        visibility_min_frac=visibility_min_frac,
                        main_class=i,
                        low=True,
                        cut_annotated=True,
                        min_frames=1,
                    )
                )
        # NOTE(review): with background_threshold=None this is called with an empty list -
        # confirm that valleys_intersection handles that case.
        valleys = task.dataset("train").valleys_intersection(valleys)
        folder = os.path.join(
            self.project_path, "results", "suggestions", suggestions_name
        )
        # NOTE(review): only the parent of `folder` is created here; os.listdir below
        # still raises if the suggestions folder itself does not exist - confirm intended.
        os.makedirs(os.path.dirname(folder), exist_ok=True)
        res = {}
        # NOTE(review): every file in the folder is read, whatever its name -
        # presumably only "<video_id>_suggestion.pickle" files are expected here.
        for file in os.listdir(folder):
            video_id = file.split("_suggestion.p")[0]
            res[video_id] = []
            with open(os.path.join(folder, file), "rb") as f:
                data = pickle.load(f)
            # data[2] and data[3] are iterated together as clip IDs and per-clip lists of
            # categories, each category being a list of (start, end, amb) intervals.
            for clip_id, ind_list in zip(data[2], data[3]):
                # The largest interval end over all categories (0 if there are no intervals).
                max_len = max(
                    [
                        max([x[1] for x in cat_list]) if len(cat_list) > 0 else 0
                        for cat_list in ind_list
                    ]
                )
                if max_len == 0:
                    continue
                # Mark every index that is covered by a suggestion...
                arr = torch.zeros(max_len)
                for cat_list in ind_list:
                    for start, end, amb in cat_list:
                        arr[start:end] = 1
                # ...or by a background interval of the same clip.
                if video_id in valleys:
                    for start, end, clip in valleys[video_id]:
                        if clip == clip_id:
                            arr[start:end] = 1
                # Run-length encoding: `output` holds one value per run of equal
                # elements and `indices` maps every position to its run.
                output, indices, counts = torch.unique_consecutive(
                    arr > 0, return_inverse=True, return_counts=True
                )
                # Keep the runs of marked positions only.
                long_indices = torch.where(output)[0]
                # First and last position of every marked run.
                res[video_id] += [
                    (
                        (indices == i).nonzero(as_tuple=True)[0][0].item(),
                        (indices == i).nonzero(as_tuple=True)[0][-1].item(),
                        clip_id,
                    )
                    for i in long_indices
                ]
        return res
4928
4929    def _make_al_points(
4930        self,
4931        task: TaskDispatcher,
4932        predicted_error: torch.Tensor,
4933        predicted_classes: torch.Tensor,
4934        exclude_classes: List,
4935        exclude_threshold: List,
4936        exclude_threshold_diff: List,
4937        exclude_hysteresis: List,
4938        include_classes: List,
4939        include_threshold: List,
4940        include_threshold_diff: List,
4941        include_hysteresis: List,
4942        error_episode: str = None,
4943        error_class: str = None,
4944        suggestion_episodes: List = None,
4945        error_threshold: float = 0.5,
4946        error_threshold_diff: float = 0.1,
4947        error_hysteresis: bool = False,
4948        min_frames_al: int = 30,
4949        visibility_min_score: float = 5,
4950        visibility_min_frac: float = 0.7,
4951    ) -> Dict:
4952        """Generate an active learning file."""
4953        if len(exclude_classes) > 0 or len(include_classes) > 0:
4954            valleys = []
4955            included = None
4956            excluded = None
4957            for class_name, thr, thr_diff, hysteresis in zip(
4958                exclude_classes,
4959                exclude_threshold,
4960                exclude_threshold_diff,
4961                exclude_hysteresis,
4962            ):
4963                episode = self._episodes().get_runs(suggestion_episodes[0])[0]
4964                class_index = self._episode(episode).get_class_ind(class_name)
4965                valleys.append(
4966                    task.dataset("train").find_valleys(
4967                        predicted_classes,
4968                        predicted_error=predicted_error,
4969                        min_frames=min_frames_al,
4970                        threshold=thr,
4971                        visibility_min_score=visibility_min_score,
4972                        visibility_min_frac=visibility_min_frac,
4973                        error_threshold=error_threshold,
4974                        main_class=class_index,
4975                        low=True,
4976                        threshold_diff=thr_diff,
4977                        min_frames_error=min_frames_al,
4978                        hysteresis=hysteresis,
4979                    )
4980                )
4981            if len(valleys) > 0:
4982                included = task.dataset("train").valleys_union(valleys)
4983            valleys = []
4984            for class_name, thr, thr_diff, hysteresis in zip(
4985                include_classes,
4986                include_threshold,
4987                include_threshold_diff,
4988                include_hysteresis,
4989            ):
4990                episode = self._episodes().get_runs(suggestion_episodes[0])[0]
4991                class_index = self._episode(episode).get_class_ind(class_name)
4992                valleys.append(
4993                    task.dataset("train").find_valleys(
4994                        predicted_classes,
4995                        predicted_error=predicted_error,
4996                        min_frames=min_frames_al,
4997                        threshold=thr,
4998                        visibility_min_score=visibility_min_score,
4999                        visibility_min_frac=visibility_min_frac,
5000                        error_threshold=error_threshold,
5001                        main_class=class_index,
5002                        low=False,
5003                        threshold_diff=thr_diff,
5004                        min_frames_error=min_frames_al,
5005                        hysteresis=hysteresis,
5006                    )
5007                )
5008            if len(valleys) > 0:
5009                excluded = task.dataset("train").valleys_union(valleys)
5010            al_points = task.dataset("train").valleys_intersection([included, excluded])
5011        else:
5012            class_index = self._episode(error_episode).get_class_ind(error_class)
5013            print("generating active learning intervals...")
5014            al_points = task.dataset("train").find_valleys(
5015                predicted_error,
5016                min_frames=min_frames_al,
5017                threshold=error_threshold,
5018                visibility_min_score=visibility_min_score,
5019                visibility_min_frac=visibility_min_frac,
5020                main_class=class_index,
5021                low=True,
5022                threshold_diff=error_threshold_diff,
5023                min_frames_error=min_frames_al,
5024                hysteresis=error_hysteresis,
5025            )
5026        for v_id in al_points:
5027            clip_dict = defaultdict(lambda: [])
5028            res = []
5029            for x in al_points[v_id]:
5030                clip_dict[x[-1]].append(x)
5031            for clip_id in clip_dict:
5032                clip_dict[clip_id] = sorted(clip_dict[clip_id])
5033                i = 0
5034                j = 1
5035                while j < len(clip_dict[clip_id]):
5036                    end = clip_dict[clip_id][i][1]
5037                    start = clip_dict[clip_id][j][0]
5038                    if start - end < 30:
5039                        clip_dict[clip_id][i][1] = clip_dict[clip_id][j][1]
5040                    else:
5041                        res.append(clip_dict[clip_id][i])
5042                        i = j
5043                    j += 1
5044                res.append(clip_dict[clip_id][i])
5045            al_points[v_id] = sorted(res)
5046        return al_points
5047
5048    def _make_suggestions(
5049        self,
5050        task: TaskDispatcher,
5051        predicted_error: torch.Tensor,
5052        predicted_classes: torch.Tensor,
5053        suggestion_threshold: List,
5054        suggestion_threshold_diff: List,
5055        suggestion_hysteresis: List,
5056        suggestion_episodes: List = None,
5057        suggestion_classes: List = None,
5058        error_threshold: float = 0.5,
5059        min_frames_suggestion: int = 3,
5060        min_frames_al: int = 30,
5061        visibility_min_score: float = 0,
5062        visibility_min_frac: float = 0.7,
5063        cut_annotated: bool = False,
5064    ) -> Dict:
5065        """Make a suggestions dictionary."""
5066        suggestions = defaultdict(lambda: {})
5067        for class_name, thr, thr_diff, hysteresis in zip(
5068            suggestion_classes,
5069            suggestion_threshold,
5070            suggestion_threshold_diff,
5071            suggestion_hysteresis,
5072        ):
5073            episode = self._episodes().get_runs(suggestion_episodes[0])[0]
5074            class_index = self._episode(episode).get_class_ind(class_name)
5075            print(f"generating suggestions for {class_name}...")
5076            found = task.dataset("train").find_valleys(
5077                predicted_classes,
5078                smooth_interval=2,
5079                predicted_error=predicted_error,
5080                min_frames=min_frames_suggestion,
5081                threshold=thr,
5082                visibility_min_score=visibility_min_score,
5083                visibility_min_frac=visibility_min_frac,
5084                error_threshold=error_threshold,
5085                main_class=class_index,
5086                low=False,
5087                threshold_diff=thr_diff,
5088                min_frames_error=min_frames_al,
5089                hysteresis=hysteresis,
5090                cut_annotated=cut_annotated,
5091            )
5092            for v_id in found:
5093                suggestions[v_id][class_name] = found[v_id]
5094        suggestions = dict(suggestions)
5095        return suggestions
5096
5097    def count_classes(
5098        self,
5099        load_episode: str = None,
5100        parameters_update: Dict = None,
5101        remove_saved_features: bool = False,
5102        bouts: bool = True,
5103    ) -> Dict:
5104        """Get a dictionary of class counts in different modes.
5105
5106        Parameters
5107        ----------
5108        load_episode : str, optional
5109            the episode settings to load
5110        parameters_update : dict, optional
5111            a dictionary of parameter updates (only for "data" and "general" categories)
5112        remove_saved_features : bool, default False
5113            if `True`, the dataset that is used for computation is then deleted
5114        bouts : bool, default False
5115            if `True`, instead of frame counts segment counts are returned
5116
5117        Returns
5118        -------
5119        class_counts : dict
5120            a dictionary where first-level keys are "train", "val" and "test", second-level keys are
5121            class names and values are class counts (in frames)
5122
5123        """
5124        if load_episode is None:
5125            task, parameters = self._make_task_training(
5126                episode_name="_", parameters_update=parameters_update, throwaway=True
5127            )
5128        else:
5129            task, parameters, _ = self._make_task_prediction(
5130                "_",
5131                load_episode=load_episode,
5132                parameters_update=parameters_update,
5133            )
5134        class_counts = task.count_classes(bouts=bouts)
5135        behaviors = task.behaviors_dict()
5136        class_counts = {
5137            kk: {behaviors.get(k, "unknown"): v for k, v in vv.items()}
5138            for kk, vv in class_counts.items()
5139        }
5140        if remove_saved_features:
5141            self._remove_stores(parameters)
5142        return class_counts
5143
5144    def plot_class_distribution(
5145        self,
5146        parameters_update: Dict = None,
5147        frame_cutoff: int = 1,
5148        bout_cutoff: int = 1,
5149        print_full: bool = False,
5150        remove_saved_features: bool = False,
5151        save: str = None,
5152    ) -> None:
5153        """Make a class distribution plot.
5154
5155        You can either specify the parameters, choose an existing dataset or do neither (in that case a dataset
5156        is created or loaded for the computation with the default parameters).
5157
5158        Parameters
5159        ----------
5160        parameters_update : dict, optional
5161            a dictionary of parameter updates (only for "data" and "general" categories)
5162        frame_cutoff : int, default 1
5163            the minimum number of frames for a segment to be considered
5164        bout_cutoff : int, default 1
5165            the minimum number of bouts for a class to be considered
5166        print_full : bool, default False
5167            if `True`, the full class distribution is printed
5168        remove_saved_features : bool, default False
5169            if `True`, the dataset that is used for computation is then deleted
5170
5171        """
5172        task, parameters = self._make_task_training(
5173            episode_name="_", parameters_update=parameters_update, throwaway=True
5174        )
5175        cutoff = {True: bout_cutoff, False: frame_cutoff}
5176        for bouts in [True, False]:
5177            class_counts = task.count_classes(bouts=bouts)
5178            if print_full:
5179                print("Bouts:" if bouts else "Frames:")
5180                for k, v in class_counts.items():
5181                    if sum(v.values()) != 0:
5182                        print(f"  {k}:")
5183                        values, keys = zip(
5184                            *[
5185                                x
5186                                for x in sorted(zip(v.values(), v.keys()), reverse=True)
5187                                if x[-1] != -100
5188                            ]
5189                        )
5190                        for kk, vv in zip(keys, values):
5191                            print(f"    {task.behaviors_dict()[kk]}: {vv}")
5192            class_counts = {
5193                kk: {k: v for k, v in vv.items() if v >= cutoff[bouts]}
5194                for kk, vv in class_counts.items()
5195            }
5196            for key, d in class_counts.items():
5197                if sum(d.values()) != 0:
5198                    values, keys = zip(
5199                        *[x for x in sorted(zip(d.values(), d.keys())) if x[-1] != -100]
5200                    )
5201                    keys = [task.behaviors_dict()[x] for x in keys]
5202                    plt.bar(keys, values)
5203                    plt.title(key)
5204                    plt.xticks(rotation=45, ha="right")
5205                    if bouts:
5206                        plt.ylabel("bouts")
5207                    else:
5208                        plt.ylabel("frames")
5209                    plt.tight_layout()
5210
5211                    if save is None:
5212                        plt.savefig(save)
5213                        plt.close()
5214                    else:
5215                        plt.show()
5216        if remove_saved_features:
5217            self._remove_stores(parameters)
5218
5219    def _generate_mask(
5220        self,
5221        mask_name: str,
5222        perc_annotated: float = 0.1,
5223        parameters_update: Dict = None,
5224        remove_saved_features: bool = False,
5225    ) -> None:
5226        """Generate a real_lens for active learning simulation.
5227
5228        Parameters
5229        ----------
5230        mask_name : str
5231            the name of the real_lens
5232        perc_annotated : float, default 0.1
5233            a
5234
5235        """
5236        print(f"GENERATING {mask_name}")
5237        task, parameters = self._make_task_training(
5238            f"_{mask_name}", parameters_update=parameters_update, throwaway=True
5239        )
5240        val_intervals, val_ids = task.dataset("val").get_intervals()  # 1
5241        unannotated_intervals = task.dataset("train").get_unannotated_intervals()  # 2
5242        unannotated_intervals = task.dataset("val").get_unannotated_intervals(
5243            first_intervals=unannotated_intervals
5244        )
5245        ids = task.dataset("train").get_ids()
5246        mask = {video_id: {} for video_id in ids}
5247        total_all = 0
5248        total_masked = 0
5249        for video_id, clip_ids in ids.items():
5250            for clip_id in clip_ids:
5251                frames = np.ones(task.dataset("train").get_len(video_id, clip_id))
5252                if clip_id in val_intervals[video_id]:
5253                    for start, end in val_intervals[video_id][clip_id]:
5254                        frames[start:end] = 0
5255                if clip_id in unannotated_intervals[video_id]:
5256                    for start, end in unannotated_intervals[video_id][clip_id]:
5257                        frames[start:end] = 0
5258                annotated = np.where(frames)[0]
5259                total_all += len(annotated)
5260                masked = annotated[-int(len(annotated) * (1 - perc_annotated)) :]
5261                total_masked += len(masked)
5262                mask[video_id][clip_id] = self._get_intervals(masked)
5263        file = {
5264            "masked": mask,
5265            "val_intervals": val_intervals,
5266            "val_ids": val_ids,
5267            "unannotated": unannotated_intervals,
5268        }
5269        self._save_mask(file, mask_name)
5270        if remove_saved_features:
5271            self._remove_stores(parameters)
5272        print("\n")
5273        # print(f'Unmasked: {sum([(vv == 0).sum() for v in real_lens.values() for vv in v.values()])} frames')
5274
5275    def _get_intervals(self, frame_indices: np.ndarray):
5276        """Get a list of intervals from a list of frame indices.
5277
5278        Example: `[0, 1, 2, 5, 6, 8] -> [[0, 3], [5, 7], [8, 9]]`.
5279
5280        Parameters
5281        ----------
5282        frame_indices : np.ndarray
5283            a list of frame indices
5284
5285        Returns
5286        -------
5287        intervals : list
5288            a list of interval boundaries
5289
5290        """
5291        masked_intervals = []
5292        if len(frame_indices) > 0:
5293            breaks = np.where(np.diff(frame_indices) != 1)[0]
5294            start = frame_indices[0]
5295            for k in breaks:
5296                masked_intervals.append([start, frame_indices[k] + 1])
5297                start = frame_indices[k + 1]
5298            masked_intervals.append([start, frame_indices[-1] + 1])
5299        return masked_intervals
5300
5301    def _update_mask_with_uncertainty(
5302        self,
5303        mask_name: str,
5304        episode_name: Union[str, None],
5305        classes: List,
5306        load_epoch: int = None,
5307        n_frames: int = 10000,
5308        method: str = "least_confidence",
5309        min_length: int = 30,
5310        augment_n: int = 0,
5311        parameters_update: Dict = None,
5312    ):
5313        """Update real_lens with frame-wise uncertainty scores for active learning.
5314
5315        Parameters
5316        ----------
5317        mask_name : str
5318            the name of the real_lens
5319        episode_name : str
5320            the name of the episode to load
5321        classes : list
5322            a list of class names or indices; their uncertainty scores will be computed separately and stacked
5323        load_epoch : int, optional
5324            the epoch to load (by default last; if this epoch is not saved the closest checkpoint is chosen)
5325        n_frames : int, default 10000
5326            the number of frames to "annotate"
5327        method : {"least_confidence", "entropy"}
5328            the method used to calculate the scores from the probability predictions (`"least_confidence"`: `1 - p_i` if
5329            `p_i > 0.5` or `p_i` if `p_i < 0.5`; `"entropy"`: `- p_i * log(p_i) - (1 - p_i) * log(1 - p_i)`)
5330        min_length : int
5331            the minimum length (in frames) of the annotated intervals
5332        augment_n : int, default 0
5333            the number of augmentations to average over
5334        parameters_update : dict, optional
5335            the dictionary used to update the parameters from the config
5336
5337        Returns
5338        -------
5339        score_dicts : dict
5340            a nested dictionary where first level keys are video ids, second level keys are clip ids and values
5341            are score tensors
5342
5343        """
5344        print(f"UPDATING {mask_name}")
5345        task, parameters, _ = self._make_task_prediction(
5346            prediction_name=mask_name,
5347            load_episode=episode_name,
5348            parameters_update=parameters_update,
5349            load_epoch=load_epoch,
5350            mode="train",
5351        )
5352        score_tensors = task.generate_uncertainty_score(classes, augment_n, method)
5353        self._update_mask(task, mask_name, score_tensors, n_frames, min_length)
5354        print("\n")
5355
5356    def _update_mask_with_BALD(
5357        self,
5358        mask_name: str,
5359        episode_name: str,
5360        classes: List,
5361        load_epoch: int = None,
5362        augment_n: int = 0,
5363        n_frames: int = 10000,
5364        num_models: int = 10,
5365        kernel_size: int = 11,
5366        min_length: int = 30,
5367        parameters_update: Dict = None,
5368    ):
5369        """Update real_lens with frame-wise Bayesian Active Learning by Disagreement scores for active learning.
5370
5371        Parameters
5372        ----------
5373        mask_name : str
5374            the name of the real_lens
5375        episode_name : str
5376            the name of the episode to load
5377        classes : list
5378            a list of class names or indices; their uncertainty scores will be computed separately and stacked
5379        load_epoch : int, optional
5380            the epoch to load (by default last)
5381        augment_n : int, default 0
5382            the number of augmentations to average over
5383        n_frames : int, default 10000
5384            the number of frames to "annotate"
5385        num_models : int, default 10
5386            the number of dropout masks to apply
5387        kernel_size : int, default 11
5388            the size of the smoothing gaussian kernel
5389        min_length : int
5390            the minimum length (in frames) of the annotated intervals
5391        parameters_update : dict, optional
5392            the dictionary used to update the parameters from the config
5393
5394        Returns
5395        -------
5396        score_dicts : dict
5397            a nested dictionary where first level keys are video ids, second level keys are clip ids and values
5398            are score tensors
5399
5400        """
5401        print(f"UPDATING {mask_name}")
5402        task, parameters, mode = self._make_task_prediction(
5403            mask_name,
5404            load_episode=episode_name,
5405            parameters_update=parameters_update,
5406            load_epoch=load_epoch,
5407        )
5408        score_tensors = task.generate_bald_score(
5409            classes, augment_n, num_models, kernel_size
5410        )
5411        self._update_mask(task, mask_name, score_tensors, n_frames, min_length)
5412        print("\n")
5413
5414    def _suggest_intervals(
5415        self,
5416        dataset: BehaviorDataset,
5417        score_tensors: Dict,
5418        n_frames: int,
5419        min_length: int,
5420    ) -> Dict:
5421        """Suggest intervals with highest score of total length `n_frames`.
5422
5423        Parameters
5424        ----------
5425        dataset : BehaviorDataset
5426            the dataset
5427        score_tensors : dict
5428            a dictionary where keys are clip ids and values are framewise score tensors
5429        n_frames : int
5430            the number of frames to "annotate"
5431        min_length : int
5432            minimum length of suggested intervals
5433
5434        Returns
5435        -------
5436        active_learning_intervals : Dict
5437            active learning dictionary with suggested intervals
5438
5439        """
5440        video_intervals, _ = dataset.get_intervals()
5441        taken = {
5442            video_id: defaultdict(lambda: {}) for video_id in video_intervals.keys()
5443        }
5444        annotated = dataset.get_annotated_intervals()
5445        for video_id in video_intervals:
5446            for clip_id in video_intervals[video_id]:
5447                taken[video_id][clip_id] = torch.zeros(
5448                    dataset.get_len(video_id, clip_id)
5449                )
5450                if video_id in annotated and clip_id in annotated[video_id]:
5451                    for start, end in annotated[video_id][clip_id]:
5452                        score_tensors[video_id][clip_id][:, start:end] = -10
5453                        taken[video_id][clip_id][int(start) : int(end)] = 1
5454        n_frames = (
5455            sum([(vv == 1).sum() for v in taken.values() for vv in v.values()])
5456            + n_frames
5457        )
5458        factor = 1
5459        threshold_start = float(
5460            torch.mean(
5461                torch.tensor(
5462                    [
5463                        torch.mean(
5464                            torch.tensor([torch.mean(y[y > 0]) for y in x.values()])
5465                        )
5466                        for x in score_tensors.values()
5467                    ]
5468                )
5469            )
5470        )
5471        while (
5472            sum([(vv == 1).sum() for v in taken.values() for vv in v.values()])
5473            < n_frames
5474        ):
5475            threshold = threshold_start * factor
5476            intervals = []
5477            interval_scores = []
5478            key1 = list(score_tensors.keys())[0]
5479            key2 = list(score_tensors[key1].keys())[0]
5480            num_scores = score_tensors[key1][key2].shape[0]
5481            for i in range(num_scores):
5482                v_dict = dataset.find_valleys(
5483                    predicted=score_tensors,
5484                    threshold=threshold,
5485                    min_frames=min_length,
5486                    main_class=i,
5487                    low=False,
5488                )
5489                for v_id, interval_list in v_dict.items():
5490                    intervals += [x + [v_id] for x in interval_list]
5491                    interval_scores += [
5492                        float(torch.mean(score_tensors[v_id][clip_id][i, start:end]))
5493                        for start, end, clip_id in interval_list
5494                    ]
5495            intervals = np.array(intervals)[np.argsort(interval_scores)[::-1]]
5496            i = 0
5497            while sum(
5498                [(vv == 1).sum() for v in taken.values() for vv in v.values()]
5499            ) < n_frames and i < len(intervals):
5500                start, end, clip_id, video_id = intervals[i]
5501                i += 1
5502                taken[video_id][clip_id][int(start) : int(end)] = 1
5503            factor *= 0.9
5504            if factor < 0.05:
5505                warnings.warn(f"Could not find enough frames!")
5506                break
5507        active_learning_intervals = {video_id: [] for video_id in video_intervals}
5508        for video_id in taken:
5509            for clip_id in taken[video_id]:
5510                if video_id in annotated and clip_id in annotated[video_id]:
5511                    for start, end in annotated[video_id][clip_id]:
5512                        taken[video_id][clip_id][int(start) : int(end)] = 0
5513                if (taken[video_id][clip_id] == 1).sum() == 0:
5514                    continue
5515                indices = np.where(taken[video_id][clip_id].numpy())[0]
5516                boundaries = self._get_intervals(indices)
5517                active_learning_intervals[video_id] += [
5518                    [start, end, clip_id] for start, end in boundaries
5519                ]
5520        return active_learning_intervals
5521
5522    def _update_mask(
5523        self,
5524        task: TaskDispatcher,
5525        mask_name: str,
5526        score_tensors: Dict,
5527        n_frames: int,
5528        min_length: int,
5529    ) -> None:
5530        """Update the real_lens with intervals with the highest score of total length `n_frames`.
5531
5532        Parameters
5533        ----------
5534        task : TaskDispatcher
5535            the task dispatcher object
5536        mask_name : str
5537            the name of the real_lens
5538        score_tensors : dict
5539            a dictionary where keys are clip ids and values are framewise score tensors
5540        n_frames : int
5541            the number of frames to "annotate"
5542        min_length : int
5543            the minimum length of the annotated intervals
5544
5545        """
5546        mask = self._load_mask(mask_name)
5547        video_intervals, _ = task.dataset("train").get_intervals()
5548        masked = {
5549            video_id: defaultdict(lambda: {}) for video_id in video_intervals.keys()
5550        }
5551        total_masked = 0
5552        total_all = 0
5553        for video_id in video_intervals:
5554            for clip_id in video_intervals[video_id]:
5555                masked[video_id][clip_id] = torch.zeros(
5556                    task.dataset("train").get_len(video_id, clip_id)
5557                )
5558                if (
5559                    video_id in mask["unannotated"]
5560                    and clip_id in mask["unannotated"][video_id]
5561                ):
5562                    for start, end in mask["unannotated"][video_id][clip_id]:
5563                        score_tensors[video_id][clip_id][:, start:end] = -10
5564                        masked[video_id][clip_id][int(start) : int(end)] = 1
5565                if (
5566                    video_id in mask["val_intervals"]
5567                    and clip_id in mask["val_intervals"][video_id]
5568                ):
5569                    for start, end in mask["val_intervals"][video_id][clip_id]:
5570                        score_tensors[video_id][clip_id][:, start:end] = -10
5571                        masked[video_id][clip_id][int(start) : int(end)] = 1
5572                total_all += torch.sum(masked[video_id][clip_id] == 0)
5573                if video_id in mask["masked"] and clip_id in mask["masked"][video_id]:
5574                    # print(f'{real_lens["masked"][video_id][clip_id]=}')
5575                    for start, end in mask["masked"][video_id][clip_id]:
5576                        masked[video_id][clip_id][int(start) : int(end)] = 1
5577                        total_masked += end - start
5578        old_n_frames = sum(
5579            [(vv == 0).sum() for v in masked.values() for vv in v.values()]
5580        )
5581        n_frames = old_n_frames + n_frames
5582        factor = 1
5583        while (
5584            sum([(vv == 0).sum() for v in masked.values() for vv in v.values()])
5585            < n_frames
5586        ):
5587            threshold = float(
5588                torch.mean(
5589                    torch.tensor(
5590                        [
5591                            torch.mean(
5592                                torch.tensor([torch.mean(y[y > 0]) for y in x.values()])
5593                            )
5594                            for x in score_tensors.values()
5595                        ]
5596                    )
5597                )
5598            )
5599            threshold = threshold * factor
5600            intervals = []
5601            interval_scores = []
5602            key1 = list(score_tensors.keys())[0]
5603            key2 = list(score_tensors[key1].keys())[0]
5604            num_scores = score_tensors[key1][key2].shape[0]
5605            for i in range(num_scores):
5606                v_dict = task.dataset("train").find_valleys(
5607                    predicted=score_tensors,
5608                    threshold=threshold,
5609                    min_frames=min_length,
5610                    main_class=i,
5611                    low=False,
5612                )
5613                for v_id, interval_list in v_dict.items():
5614                    intervals += [x + [v_id] for x in interval_list]
5615                    interval_scores += [
5616                        float(torch.mean(score_tensors[v_id][clip_id][i, start:end]))
5617                        for start, end, clip_id in interval_list
5618                    ]
5619            intervals = np.array(intervals)[np.argsort(interval_scores)[::-1]]
5620            i = 0
5621            while sum(
5622                [(vv == 0).sum() for v in masked.values() for vv in v.values()]
5623            ) < n_frames and i < len(intervals):
5624                start, end, clip_id, video_id = intervals[i]
5625                i += 1
5626                masked[video_id][clip_id][int(start) : int(end)] = 0
5627            factor *= 0.9
5628            if factor < 0.05:
5629                warnings.warn(f"Could not find enough frames!")
5630                break
5631        mask["masked"] = {video_id: {} for video_id in video_intervals}
5632        total_masked_new = 0
5633        for video_id in masked:
5634            for clip_id in masked[video_id]:
5635                if (
5636                    video_id in mask["unannotated"]
5637                    and clip_id in mask["unannotated"][video_id]
5638                ):
5639                    for start, end in mask["unannotated"][video_id][clip_id]:
5640                        masked[video_id][clip_id][int(start) : int(end)] = 0
5641                if (
5642                    video_id in mask["val_intervals"]
5643                    and clip_id in mask["val_intervals"][video_id]
5644                ):
5645                    for start, end in mask["val_intervals"][video_id][clip_id]:
5646                        masked[video_id][clip_id][int(start) : int(end)] = 0
5647                indices = np.where(masked[video_id][clip_id].numpy())[0]
5648                mask["masked"][video_id][clip_id] = self._get_intervals(indices)
5649        for video_id in mask["masked"]:
5650            for clip_id in mask["masked"][video_id]:
5651                for start, end in mask["masked"][video_id][clip_id]:
5652                    total_masked_new += end - start
5653        self._save_mask(mask, mask_name)
5654        with open(
5655            os.path.join(
5656                self.project_path, "results", f"{mask_name}.txt", encoding="utf-8"
5657            ),
5658            "a",
5659        ) as f:
5660            f.write(f"from {total_masked} to {total_masked_new} / {total_all}" + "\n")
5661        print(f"Unmasked from {total_masked} to {total_masked_new} / {total_all}")
5662
5663    def _visualize_results_label(
5664        self,
5665        episode_name: str,
5666        label: str,
5667        load_epoch: int = None,
5668        parameters_update: Dict = None,
5669        add_legend: bool = True,
5670        ground_truth: bool = True,
5671        hide_axes: bool = False,
5672        width: float = 10,
5673        whole_video: bool = False,
5674        transparent: bool = False,
5675        num_plots: int = 1,
5676        smooth_interval: int = 0,
5677    ):
5678        other_path = os.path.join(self.project_path, "results", "other")
5679        if not os.path.exists(other_path):
5680            os.mkdir(other_path)
5681        if parameters_update is None:
5682            parameters_update = {}
5683        if "model" in parameters_update.keys():
5684            raise ValueError("Cannot change model parameters after training!")
5685        task, parameters, _ = self._make_task_prediction(
5686            "_",
5687            load_episode=episode_name,
5688            parameters_update=parameters_update,
5689            load_epoch=load_epoch,
5690            mode="val",
5691        )
5692        for i in range(num_plots):
5693            print(i)
5694            task._visualize_results_label(
5695                smooth_interval=smooth_interval,
5696                label=label,
5697                save_path=os.path.join(
5698                    other_path, f"{episode_name}_prediction_{i}.jpg"
5699                ),
5700                add_legend=add_legend,
5701                ground_truth=ground_truth,
5702                hide_axes=hide_axes,
5703                whole_video=whole_video,
5704                transparent=transparent,
5705                dataset="val",
5706                width=width,
5707                title=str(i),
5708            )
5709
5710    def plot_confusion_matrix(
5711        self,
5712        episode_name: str,
5713        load_epoch: int = None,
5714        parameters_update: Dict = None,
5715        metric: str = "recall",
5716        mode: str = "val",
5717        remove_saved_features: bool = False,
5718        save_path: str = None,
5719        cmap: str = "viridis",
5720    ) -> Tuple[ndarray, Iterable]:
5721        """Make a confusion matrix plot and return the data.
5722
5723        If the annotation is non-exclusive, only false positive labels are considered.
5724
5725        Parameters
5726        ----------
5727        episode_name : str
5728            the name of the episode to load
5729        load_epoch : int, optional
5730            the index of the epoch to load (by default the last one is loaded)
5731        parameters_update : dict, optional
5732            a dictionary of parameter updates (only for "data" and "general" categories)
5733        metric : {"recall", "precision"}
5734            for datasets with non-exclusive annotation, if `type` is `"recall"`, only false positives are taken
5735            into account, and if `type` is `"precision"`, only false negatives
5736        mode : {'val', 'all', 'test', 'train'}
5737            the subset of the data to make the prediction for (forced to 'all' if data_path is not None)
5738        remove_saved_features : bool, default False
5739            if `True`, the dataset that is used for computation is then deleted
5740
5741        Returns
5742        -------
5743        confusion_matrix : np.ndarray
5744            a confusion matrix of shape `(#classes, #classes)` where `A[i, j] = F_ij/N_i`, `F_ij` is the number of
5745            frames that have the i-th label in the ground truth and a false positive j-th label in the prediction,
5746            `N_i` is the number of frames that have the i-th label in the ground truth
5747        classes : list
5748            a list of labels
5749
5750        """
5751        task, parameters, mode = self._make_task_prediction(
5752            "_",
5753            load_episode=episode_name,
5754            load_epoch=load_epoch,
5755            parameters_update=parameters_update,
5756            mode=mode,
5757        )
5758        dataset = task.dataset(mode)
5759        prediction = task.predict(dataset, raw_output=True)
5760        confusion_matrix, classes, type = dataset.get_confusion_matrix(prediction, type)
5761        if remove_saved_features:
5762            self._remove_stores(parameters)
5763        fig, ax = plt.subplots(figsize=(len(classes), len(classes)))
5764        ax.imshow(confusion_matrix, cmap=cmap)
5765        # Show all ticks and label them with the respective list entries
5766        ax.set_xticks(np.arange(len(classes)))
5767        ax.set_xticklabels(classes)
5768        ax.set_yticks(np.arange(len(classes)))
5769        ax.set_yticklabels(classes)
5770        # Rotate the tick labels and set their alignment.
5771        plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
5772        # Loop over data dimensions and create text annotations.
5773        for i in range(len(classes)):
5774            for j in range(len(classes)):
5775                ax.text(
5776                    j,
5777                    i,
5778                    np.round(confusion_matrix[i, j], 2),
5779                    ha="center",
5780                    va="center",
5781                    color="w",
5782                )
5783        if metric is not None:
5784            ax.set_title(f"{metric} {episode_name}")
5785        else:
5786            ax.set_title(episode_name)
5787        fig.tight_layout()
5788        if save_path is None:
5789            plt.show()
5790        else:
5791            plt.savefig(save_path)
5792            plt.close()
5793        return confusion_matrix, classes
5794
5795    def _plot_ethograms_gt_pred(
5796        self,
5797        data_gt: dict,
5798        data_pred: dict,
5799        labels_gt: list,
5800        labels_pred: list,
5801        start: int = 0,
5802        end: int = -1,
5803        cmap_pred: str = "binary",
5804        cmap_gt: str = "binary",
5805        save: str = None,
5806        fontsize=22,
5807        time_mode="frames",
5808        fps: int = None,
5809    ) -> None:
5810        """Plot ethograms from start to end time (in frames), mode can be prediction or ground truth depending on the data format."""
5811        # print(data.keys())
5812        best_pred = (
5813            data_pred[list(data_pred.keys())[0]].numpy() > 0.5
5814        )  # Threshold the predictions
5815        data_gt = binarize_data(data_gt, max_frame=end)
5816
5817        # Crop data to min length
5818        if end < 0:
5819            end = min(data_gt.shape[1], best_pred.shape[1])
5820        data_gt = data_gt[:, :end]
5821        best_pred = best_pred[:, :end]
5822
5823        # Reorder behaviors
5824        ind_gt = []
5825        ind_pred = []
5826        labels_pred = [labels_pred[i] for i in range(len(labels_pred))]
5827        labels_pred = np.roll(
5828            labels_pred, 1
5829        ).tolist()  
5830        check_gt = np.where(np.sum(data_gt, axis=1) > 0)[0]
5831        check_pred = np.where(np.sum(best_pred, axis=1) > 0)[0]
5832        for k, gt_beh in enumerate(labels_gt):
5833            if gt_beh in labels_pred:
5834                j = labels_pred.index(gt_beh)
5835                if not k in check_gt and not j in check_pred:
5836                    continue
5837                ind_gt.append(labels_gt.index(gt_beh))
5838                ind_pred.append(j)
5839        # Create label list
5840        labels = np.array(labels_gt)[ind_gt]
5841        assert (labels == np.array(labels_pred)[ind_pred]).all()
5842
5843        # # Create image
5844        image_pred = best_pred[ind_pred].astype(float)
5845        image_gt = data_gt[ind_gt]
5846
5847        f, axs = plt.subplots(
5848            len(labels), 1, figsize=(5 * len(labels), 15), sharex=True
5849        )
5850        end = image_gt.shape[1] if end < 0 else end
5851        for i, (ax, label) in enumerate(zip(axs, labels)):
5852
5853            im1 = np.array([image_gt[i], np.ones_like(image_gt[i]) * (-1)])
5854            im1 = np.ma.masked_array(im1, im1 < 0)
5855
5856            im2 = np.array([np.ones_like(image_pred[i]) * (-1), image_pred[i]])
5857            im2 = np.ma.masked_array(im2, im2 < 0)
5858
5859            ax.imshow(im1, aspect="auto", cmap=cmap_gt, interpolation="nearest")
5860            ax.imshow(im2, aspect="auto", cmap=cmap_pred, interpolation="nearest")
5861
5862            ax.set_yticks(np.arange(2), ["GT", "Pred"], fontsize=fontsize)
5863            ax.tick_params(axis="x", labelsize=fontsize)
5864            ax.set_ylabel(label, fontsize=fontsize)
5865            if time_mode == "frames":
5866                ax.set_xlabel("Num Frames", fontsize=fontsize)
5867            elif time_mode == "seconds":
5868                assert not fps is None, "Please provide fps"
5869                ax.set_xlabel("Time (s)", fontsize=fontsize)
5870                ax.set_xticks(
5871                    np.linspace(0, end, 10),
5872                    np.linspace(0, end / fps, 10).astype(np.int32),
5873                )
5874
5875            ax.set_xlim(start, end)
5876
5877        if save is None:
5878            plt.show()
5879        else:
5880            plt.savefig(save)
5881            plt.close()
5882
5883    def plot_ethograms(
5884        self,
5885        episode_name: str,
5886        prediction_name: str,
5887        start: int = 0,
5888        end: int = -1,
5889        save_path: str = None,
5890        cmap_pred: str = "binary",
5891        cmap_gt: str = "binary",
5892        fontsize: int = 22,
5893        time_mode: str = "frames",
5894        fps: int = None,
5895    ):
5896        """Plot ethograms from start to end time (in frames) for ground truth and prediction"""
5897        params = self._read_parameters(catch_blanks=False)
5898        parameters = self._get_data_pars(
5899            params,
5900        )
5901        if not save_path is None:
5902            os.makedirs(save_path, exist_ok=True)
5903        gt_files = [
5904            f for f in self.data_path if f.endswith(parameters["annotation_suffix"])
5905        ]
5906        pred_path = os.path.join(
5907            self.project_path, "results", "predictions", prediction_name
5908        )
5909        pred_paths = [os.path.join(pred_path, f) for f in os.listdir(pred_path)]
5910        for pred_path in pred_paths:
5911            predictions = load_pickle(pred_path)
5912            behaviors = self.get_behavior_dictionary(episode_name)
5913            gt_filename = os.path.basename(pred_path).replace(
5914                "_".join(["_" + prediction_name, "prediction.pickle"]),
5915                parameters["annotation_suffix"],
5916            )
5917            if os.path.exists(os.path.join(self.data_path, gt_filename)):
5918                gt_data = load_pickle(os.path.join(self.data_path, gt_filename))
5919
5920                self._plot_ethograms_gt_pred(
5921                    gt_data,
5922                    predictions,
5923                    gt_data[1],
5924                    behaviors,
5925                    start=start,
5926                    end=end,
5927                    save=os.path.join(
5928                        save_path,
5929                        os.path.splitext(os.path.basename(pred_path))[0] + "_gt_pred",
5930                    ),
5931                    cmap_pred=cmap_pred,
5932                    cmap_gt=cmap_gt,
5933                    fontsize=fontsize,
5934                    time_mode=time_mode,
5935                    fps=fps,
5936                )
5937            else:
5938                print("GT file not found")
5939
5940    def _create_side_panel(self, height, width, labels_pred, preds, labels_gt, gt=None):
5941        """Create a side panel for video annotation display.
5942
5943        Parameters
5944        ----------
5945        height : int
5946            the height of the panel
5947        width : int
5948            the width of the panel
5949        labels_pred : list
5950            the list of predicted behavior labels
5951        preds : array-like
5952            the prediction values for each behavior
5953        labels_gt : list
5954            the list of ground truth behavior labels
5955        gt : array-like, optional
5956            the ground truth values for each behavior
5957
5958        Returns
5959        -------
5960        side_panel : np.ndarray
5961            the created side panel as an image array
5962
5963        """
5964        side_panel = np.ones((height, int(width / 4), 3), dtype=np.uint8) * 255
5965
5966        beh_indices = np.where(preds)[0]
5967        for i, label in enumerate(labels_pred):
5968            color = (0, 0, 0)
5969            if i in beh_indices:
5970                color = (0, 255, 0)
5971            cv2.putText(
5972                side_panel,
5973                label,
5974                (10, 50 + 50 * i),
5975                cv2.FONT_HERSHEY_SIMPLEX,
5976                1,
5977                color,
5978                2,
5979                cv2.LINE_AA,
5980            )
5981        if gt is not None:
5982            beh_indices_gt = np.where(gt)[0]
5983            for i, label in enumerate(labels_gt):
5984                color = (0, 0, 0)
5985                if i in beh_indices_gt:
5986                    color = (0, 255, 0)
5987                cv2.putText(
5988                    side_panel,
5989                    label,
5990                    (10, 50 + 50 * i + 80 * len(labels_pred)),
5991                    cv2.FONT_HERSHEY_SIMPLEX,
5992                    1,
5993                    color,
5994                    2,
5995                    cv2.LINE_AA,
5996                )
5997        return side_panel
5998
5999    def create_annotated_video(
6000        self,
6001        prediction_file_paths: list,
6002        video_file_paths: list,
6003        episode_name: str,  # To get the list of behaviors
6004        ground_truth_file_paths: list = None,
6005        pred_thresh: float = 0.5,
6006        start: int = 0,
6007        end: int = -1,
6008    ):
6009        """Create a video with the predictions overlaid on the video"""
6010        for k, (pred_path, vid_path) in enumerate(
6011            zip(prediction_file_paths, video_file_paths)
6012        ):
6013            print("Generating video for :", os.path.basename(vid_path))
6014            predictions = load_pickle(pred_path)
6015            best_pred = predictions[list(predictions.keys())[0]].numpy() > pred_thresh
6016            behaviors = self.get_behavior_dictionary(episode_name)
6017            # Load video
6018            labels_pred = [behaviors[i] for i in range(len(behaviors))]
6019            labels_pred = np.roll(
6020                labels_pred, 1
6021            ).tolist() 
6022
6023            gt_data = None
6024            if ground_truth_file_paths is not None:
6025                gt_data = load_pickle(ground_truth_file_paths[k])
6026                labels_gt = gt_data[1]
6027                gt_data = binarize_data(gt_data, max_frame=best_pred.shape[1])
6028
6029            cap = cv2.VideoCapture(vid_path)
6030            cap.set(cv2.CAP_PROP_POS_FRAMES, start)
6031            end = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) if end < 0 else end
6032            fps = cap.get(cv2.CAP_PROP_FPS)
6033            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
6034            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
6035            fourcc = cv2.VideoWriter_fourcc(*"mp4v")
6036            out = cv2.VideoWriter(
6037                os.path.join(
6038                    os.path.dirname(vid_path),
6039                    os.path.splitext(os.path.basename(vid_path))[0] + "_annotated.mp4",
6040                ),
6041                fourcc,
6042                fps,
6043                # (width + int(width/4) , height),
6044                (600, 300),
6045            )
6046            count = 0
6047            bar = tqdm(total=end - start)
6048            while cap.isOpened():
6049                ret, frame = cap.read()
6050                if not ret:
6051                    break
6052
6053                side_panel = self._create_side_panel(
6054                    height,
6055                    width,
6056                    labels_pred,
6057                    best_pred[:, count],
6058                    labels_gt,
6059                    gt_data[:, count],
6060                )
6061                frame = np.concatenate((frame, side_panel), axis=1)
6062                frame = cv2.resize(frame, (0, 0), fx=0.25, fy=0.25)
6063                out.write(frame)
6064                count += 1
6065                bar.update(1)
6066
6067                if count > end:
6068                    break
6069
6070            cap.release()
6071            out.release()
6072            cv2.destroyAllWindows()
6073
6074    def plot_predictions(
6075        self,
6076        episode_name: str,
6077        load_epoch: int = None,
6078        parameters_update: Dict = None,
6079        add_legend: bool = True,
6080        ground_truth: bool = True,
6081        colormap: str = "dlc2action",
6082        hide_axes: bool = False,
6083        min_classes: int = 1,
6084        width: float = 10,
6085        whole_video: bool = False,
6086        transparent: bool = False,
6087        drop_classes: Set = None,
6088        search_classes: Set = None,
6089        num_plots: int = 1,
6090        remove_saved_features: bool = False,
6091        smooth_interval_prediction: int = 0,
6092        data_path: str = None,
6093        file_paths: Set = None,
6094        mode: str = "val",
6095        font_size: float = None,
6096        window_size: int = 400,
6097    ) -> None:
6098        """Visualize random predictions.
6099
6100        Parameters
6101        ----------
6102        episode_name : str
6103            the name of the episode to load
6104        load_epoch : int, optional
6105            the epoch to load (by default last)
6106        parameters_update : dict, optional
6107            parameter update dictionary
6108        add_legend : bool, default True
6109            if True, legend will be added to the plot
6110        ground_truth : bool, default True
6111            if True, ground truth will be added to the plot
6112        colormap : str, default 'Accent'
6113            the `matplotlib` colormap to use
6114        hide_axes : bool, default True
6115            if `True`, the axes will be hidden on the plot
6116        min_classes : int, default 1
6117            the minimum number of classes in a displayed interval
6118        width : float, default 10
6119            the width of the plot
6120        whole_video : bool, default False
6121            if `True`, whole videos are plotted instead of segments
6122        transparent : bool, default False
6123            if `True`, the background on the plot is transparent
6124        drop_classes : set, optional
6125            a set of class names to not be displayed
6126        search_classes : set, optional
6127            if given, only intervals where at least one of the classes is in ground truth will be shown
6128        num_plots : int, default 1
6129            the number of plots to make
6130        remove_saved_features : bool, default False
6131            if `True`, the dataset will be deleted after computation
6132        smooth_interval_prediction : int, default 0
6133            if >0, predictions shorter than this number of frames are removed (filled with prediction for the previous frame)
6134        data_path : str, optional
6135            the data path to run the prediction for
6136        file_paths : set, optional
6137            a set of string file paths (data with all prefixes + feature files, in any order) to run the prediction
6138            for
6139        mode : {'all', 'test', 'val', 'train'}
6140            the subset of the data to make the prediction for (forced to 'all' if data_path is not None)
6141
6142        """
6143        plot_path = os.path.join(self.project_path, "results", "plots")
6144        task, parameters, mode = self._make_task_prediction(
6145            "_",
6146            load_episode=episode_name,
6147            parameters_update=parameters_update,
6148            load_epoch=load_epoch,
6149            data_path=data_path,
6150            file_paths=file_paths,
6151            mode=mode,
6152        )
6153        os.makedirs(plot_path, exist_ok=True)
6154        task.visualize_results(
6155            save_path=os.path.join(plot_path, f"{episode_name}_prediction.svg"),
6156            add_legend=add_legend,
6157            ground_truth=ground_truth,
6158            colormap=colormap,
6159            hide_axes=hide_axes,
6160            min_classes=min_classes,
6161            whole_video=whole_video,
6162            transparent=transparent,
6163            dataset=mode,
6164            drop_classes=drop_classes,
6165            search_classes=search_classes,
6166            width=width,
6167            smooth_interval_prediction=smooth_interval_prediction,
6168            font_size=font_size,
6169            num_plots=num_plots,
6170            window_size=window_size,
6171        )
6172        if remove_saved_features:
6173            self._remove_stores(parameters)
6174
6175    def create_video_from_labels(
6176        self,
6177        video_dir_path: str,
6178        mode="ground_truth",
6179        prediction_name: str = None,
6180        save_path: str = None,
6181    ):
6182        if save_path is None:
6183            save_path = os.path.join(
6184                self.project_path, "results", f"annotated_videos_from_{mode}"
6185            )
6186        os.makedirs(save_path, exist_ok=True)
6187
6188        params = self._read_parameters(catch_blanks=False)
6189
6190        if mode == "ground_truth":
6191            source_dir = self.annotation_path
6192            annotation_suffix = params["data"]["annotation_suffix"]
6193        elif mode == "prediction":
6194            assert (
6195                not prediction_name is None
6196            ), "Please provide a prediction name with mode 'prediction'"
6197            source_dir = os.path.join(
6198                self.project_path, "results", "predictions", prediction_name
6199            )
6200            annotation_suffix = f"_{prediction_name}_prediction.pickle"
6201
6202        video_annotation_pairs = [
6203            (
6204                os.path.join(video_dir_path, f),
6205                os.path.join(
6206                    source_dir, f.replace(f.split(".")[-1], annotation_suffix)
6207                ),
6208            )
6209            for f in os.listdir(video_dir_path)
6210            if os.path.exists(
6211                os.path.join(source_dir, f.replace(f.split(".")[-1], annotation_suffix))
6212            )
6213        ]
6214
6215        for video_file, annotation_file in tqdm(video_annotation_pairs):
6216            if not os.path.exists(video_file):
6217                print(f"Video file {video_file} does not exist, skipping.")
6218                continue
6219            if not os.path.exists(annotation_file):
6220                print(f"Annotation file {annotation_file} does not exist, skipping.")
6221                continue
6222
6223            if annotation_file.endswith(".pickle"):
6224                annotations = load_pickle(annotation_file)
6225            elif annotation_file.endswith(".csv"):
6226                annotations = pd.read_csv(annotation_file)
6227
6228            if mode == "ground_truth":
6229                behaviors = annotations[1]
6230                annot_data = annotations[3]
6231            elif mode == "predictions":
6232                behaviors = list(annotations["classes"].values())
6233                annot_data = [
6234                    annotations[key]
6235                    for key in annotations.keys()
6236                    if key not in ["classes", "min_frame", "max_frame"]
6237                ]
6238                if params["general"]["exclusive"]:
6239                    annot_data = [np.argmax(annot, axis=1) for annot in annot_data]
6240                    seqs = [
6241                        [
6242                            self._bin_array_to_sequences(annot, target_value=k)
6243                            for k in range(len(behaviors))
6244                        ]
6245                        for annot in annot_data
6246                    ]
6247                else:
6248                    annot_data = [np.where(annot > 0.5)[0] for annot in annot_data]
6249                    seqs = [
6250                        self._bin_array_to_sequences(annot, target_value=1)
6251                        for annot in annot_data
6252                    ]
6253                annotations = ["", "", seqs]
6254
6255            for individual in annotations[3]:
6256                for behavior in annotations[3][individual]:
6257                    intervals = annotations[3][individual][behavior]
6258                    self._extract_videos(
6259                        video_file,
6260                        intervals,
6261                        behavior,
6262                        individual,
6263                        save_path,
6264                        resolution=(640, 480),
6265                        fps=30,
6266                    )
6267
6268    def _bin_array_to_sequences(
6269        self, annot_data: List[np.ndarray], target_value: int
6270    ) -> List[List[Tuple[int, int]]]:
6271        is_target = annot_data == target_value
6272        changes = np.diff(np.concatenate(([False], is_target, [False])))
6273        indices = np.where(changes)[0].reshape(-1, 2)
6274        subsequences = [list(range(start, end)) for start, end in indices]
6275        return subsequences
6276
6277    def _extract_videos(
6278        self,
6279        video_file: str,
6280        intervals: np.ndarray,
6281        behavior: str,
6282        individual: str,
6283        video_dir: str,
6284        resolution: Tuple[int, int] = (640, 480),
6285        fps: int = 30,
6286    ) -> None:
6287        """Extract frames from a video file from frames in between intervals in behavior folder for a given individual"""
6288        cap = cv2.VideoCapture(video_file)
6289        print("Extracting frames from", video_file)
6290
6291        for start, end, confusion in tqdm(intervals):
6292
6293            frame_count = start
6294            assert start < end, "Start frame should be less than end frame"
6295            if confusion > 0.5:
6296                continue
6297            cap.set(cv2.CAP_PROP_POS_FRAMES, start)
6298            output_file = os.path.join(
6299                video_dir,
6300                individual,
6301                behavior,
6302                os.path.splitext(os.path.basename(video_file))[0]
6303                + f"vid_{individual}_{behavior}_{start:05d}_{end:05d}.mp4",
6304            )
6305            fourcc = cv2.VideoWriter_fourcc(*"mp4v")  # Codec, e.g., 'XVID', 'MJPG'
6306            out = cv2.VideoWriter(
6307                output_file, fourcc, fps, (resolution[0], resolution[1])
6308            )
6309            while cap.isOpened():
6310                ret, frame = cap.read()
6311                if not ret:
6312                    break
6313
6314                # Resize large frames
6315                frame = cv2.resize(frame, (640, 480))
6316                out.write(frame)
6317
6318                frame_count += 1
6319                # Break if end frame is reached or max frames per behavior is reached
6320                if frame_count == end:
6321                    break
6322            if frame_count <= 2:
6323                os.remove(output_file)
6324            # cap.release()
6325            out.release()
6326
6327    def create_metadata_backup(self) -> None:
6328        """Create a copy of the meta files."""
6329        meta_copy_path = os.path.join(self.project_path, "meta", "backup")
6330        meta_path = os.path.join(self.project_path, "meta")
6331        if os.path.exists(meta_copy_path):
6332            shutil.rmtree(meta_copy_path)
6333        os.mkdir(meta_copy_path)
6334        for file in os.listdir(meta_path):
6335            if file == "backup":
6336                continue
6337            if os.path.isdir(os.path.join(meta_path, file)):
6338                continue
6339            shutil.copy(
6340                os.path.join(meta_path, file), os.path.join(meta_copy_path, file)
6341            )
6342
6343    def load_metadata_backup(self) -> None:
6344        """Load from previously created meta data backup (in case of corruption)."""
6345        meta_copy_path = os.path.join(self.project_path, "meta", "backup")
6346        meta_path = os.path.join(self.project_path, "meta")
6347        for file in os.listdir(meta_copy_path):
6348            shutil.copy(
6349                os.path.join(meta_copy_path, file), os.path.join(meta_path, file)
6350            )
6351
6352    def get_behavior_dictionary(self, episode_name: str) -> Dict:
6353        """Get the behavior dictionary for an episode.
6354
6355        Parameters
6356        ----------
6357        episode_name : str
6358            the name of the episode
6359
6360        Returns
6361        -------
6362        behaviors_dictionary : dict
6363            a dictionary where keys are label indices and values are label names
6364
6365        """
6366        return self._episode(episode_name).get_behaviors_dict()
6367
6368    def import_episodes(
6369        self,
6370        episodes_directory: str,
6371        name_map: Dict = None,
6372        repeat_policy: str = "error",
6373    ) -> None:
6374        """Import episodes exported with `Project.export_episodes`.
6375
6376        Parameters
6377        ----------
6378        episodes_directory : str
6379            the path to the exported episodes directory
6380        name_map : dict, optional
6381            a name change dictionary for the episodes: keys are old names, values are new names
6382        repeat_policy : {'error', 'skip', 'force'}, default 'error'
6383            the policy for repeated episode names: 'error' raises an error, 'skip' skips duplicates,
6384            'force' overwrites existing episodes
6385
6386        """
6387        if name_map is None:
6388            name_map = {}
6389        episodes = pd.read_pickle(os.path.join(episodes_directory, "episodes.pickle"))
6390        to_remove = []
6391        import_string = "Imported episodes: "
6392        for episode_name in episodes.index:
6393            if episode_name in name_map:
6394                import_string += f"{episode_name} "
6395                episode_name = name_map[episode_name]
6396                import_string += f"({episode_name}), "
6397            else:
6398                import_string += f"{episode_name}, "
6399            try:
6400                self._check_episode_validity(episode_name, allow_doublecolon=True)
6401            except ValueError as e:
6402                if str(e).endswith("is already taken!"):
6403                    if repeat_policy == "skip":
6404                        to_remove.append(episode_name)
6405                    elif repeat_policy == "force":
6406                        self.remove_episode(episode_name)
6407                    elif repeat_policy == "error":
6408                        raise ValueError(
6409                            f"The {episode_name} episode name is already taken; please use the name_map parameter to rename it"
6410                        )
6411                    else:
6412                        raise ValueError(
6413                            f"The {repeat_policy} repeat policy is not recognized; please choose from ['skip', 'force' and 'error']"
6414                        )
6415        episodes = episodes.drop(index=to_remove)
6416        self._episodes().update(
6417            episodes,
6418            name_map=name_map,
6419            force=(repeat_policy == "force"),
6420            data_path=self.data_path,
6421            annotation_path=self.annotation_path,
6422        )
6423        for episode_name in episodes.index:
6424            if episode_name in name_map:
6425                new_episode_name = name_map[episode_name]
6426            else:
6427                new_episode_name = episode_name
6428            model_dir = os.path.join(
6429                self.project_path, "results", "model", new_episode_name
6430            )
6431            old_model_dir = os.path.join(episodes_directory, "model", episode_name)
6432            if os.path.exists(model_dir):
6433                shutil.rmtree(model_dir)
6434            os.mkdir(model_dir)
6435            for file in os.listdir(old_model_dir):
6436                shutil.copyfile(
6437                    os.path.join(old_model_dir, file), os.path.join(model_dir, file)
6438                )
6439            log_file = os.path.join(
6440                self.project_path, "results", "logs", f"{new_episode_name}.txt"
6441            )
6442            old_log_file = os.path.join(
6443                episodes_directory, "logs", f"{episode_name}.txt"
6444            )
6445            shutil.copyfile(old_log_file, log_file)
6446        print(import_string)
6447        print("\n")
6448
6449    def export_episodes(
6450        self, episode_names: List, output_directory: str, name: str = None
6451    ) -> None:
6452        """Save selected episodes as a file that can be imported into another project with `Project.import_episodes`.
6453
6454        Parameters
6455        ----------
6456        episode_names : list
6457            a list of string episode names
6458        output_directory : str
6459            the path to the directory where the episodes will be saved
6460        name : str, optional
6461            the name of the episodes directory (by default `exported_episodes`)
6462
6463        """
6464        if name is None:
6465            name = "exported_episodes"
6466        if os.path.exists(
6467            os.path.join(output_directory, name + ".zip")
6468        ) or os.path.exists(os.path.join(output_directory, name)):
6469            i = 1
6470            while os.path.exists(
6471                os.path.join(output_directory, name + f"_{i}.zip")
6472            ) or os.path.exists(os.path.join(output_directory, name + f"_{i}")):
6473                i += 1
6474            name = name + f"_{i}"
6475        dest_dir = os.path.join(output_directory, name)
6476        os.mkdir(dest_dir)
6477        os.mkdir(os.path.join(dest_dir, "model"))
6478        os.mkdir(os.path.join(dest_dir, "logs"))
6479        runs = []
6480        for episode in episode_names:
6481            runs += self._episodes().get_runs(episode)
6482        for run in runs:
6483            shutil.copytree(
6484                os.path.join(self.project_path, "results", "model", run),
6485                os.path.join(dest_dir, "model", run),
6486            )
6487            shutil.copyfile(
6488                os.path.join(self.project_path, "results", "logs", f"{run}.txt"),
6489                os.path.join(dest_dir, "logs", f"{run}.txt"),
6490            )
6491        data = self._episodes().get_subset(runs)
6492        data.to_pickle(os.path.join(dest_dir, "episodes.pickle"))
6493
6494    def get_results_table(
6495        self,
6496        episode_names: List,
6497        metrics: List = None,
6498        mode: str = "mean",  # Choose between ["mean", "statistics", "detail"]
6499        print_results: bool = True,
6500        classes: List = None,
6501    ):
6502        """Generate a `pandas` dataframe with a summary of episode results.
6503
6504        Parameters
6505        ----------
6506        episode_names : list
6507            a list of names of episodes to include
6508        metrics : list, optional
6509            a list of metric names to include
6510        mode : bool, optional
6511            the mode of the results table, choose between ["mean", "statistics", "detail"], by default "mean"
6512        print_results : bool, optional
6513            if True, the results will be printed to the console, by default True
6514        classes : list, optional
6515            a list of names of classes to include (by default all are included)
6516
6517        Returns
6518        -------
6519        results : pd.DataFrame
6520            a table with the results
6521
6522        """
6523        run_names = []
6524        for episode in episode_names:
6525            run_names += self._episodes().get_runs(episode)
6526        episodes = self.list_episodes(run_names, print_results=False)
6527        metric_columns = [x for x in episodes.columns if x[0] == "results"]
6528        results_df = pd.DataFrame()
6529        if metrics is not None:
6530            metric_columns = [
6531                x for x in metric_columns if x[1].split("_")[0] in metrics
6532            ]
6533        for episode in episode_names:
6534            results = []
6535            metric_set = set()
6536            for run in self._episodes().get_runs(episode):
6537                beh_dict = self.get_behavior_dictionary(run)
6538                res_dict = defaultdict(lambda: {})
6539                for column in metric_columns:
6540                    if np.isnan(episodes.loc[run, column]):
6541                        continue
6542                    split = column[1].split("_")
6543                    if split[-1].isnumeric():
6544                        beh_ind = int(split[-1])
6545                        metric_name = "_".join(split[:-1])
6546                        beh = beh_dict[beh_ind]
6547                    else:
6548                        beh = "average"
6549                        metric_name = column[1]
6550                    res_dict[beh][metric_name] = episodes.loc[run, column]
6551                    metric_set.add(metric_name)
6552                if "average" not in res_dict:
6553                    res_dict["average"] = {}
6554                for metric in metric_set:
6555                    if metric not in res_dict["average"]:
6556                        arr = [
6557                            res_dict[beh][metric]
6558                            for beh in res_dict
6559                            if metric in res_dict[beh]
6560                        ]
6561                        res_dict["average"][metric] = np.mean(arr)
6562                results.append(res_dict)
6563            episode_results = {}
6564            for metric in metric_set:
6565                for beh in results[0].keys():
6566                    if classes is not None and beh not in classes:
6567                        continue
6568                    arr = []
6569                    for res_dict in results:
6570                        if metric in res_dict[beh]:
6571                            arr.append(res_dict[beh][metric])
6572                    if len(arr) > 0:
6573                        if mode == "statistics":
6574                            episode_results[(beh, f"{episode} {metric} mean")] = (
6575                                np.mean(arr)
6576                            )
6577                            episode_results[(beh, f"{episode} {metric} std")] = np.std(
6578                                arr
6579                            )
6580                        elif mode == "mean":
6581                            episode_results[(beh, f"{episode} {metric}")] = np.mean(arr)
6582                        elif mode == "detail":
6583                            for i, val in enumerate(arr):
6584                                episode_results[(beh, f"{episode}::{i} {metric}")] = val
6585            for key, value in episode_results.items():
6586                results_df.loc[key[0], key[1]] = value
6587        if print_results:
6588            print(f"RESULTS:")
6589            print(results_df)
6590            print("\n")
6591        return results_df
6592
6593    def episode_exists(self, episode_name: str) -> bool:
6594        """Check if an episode already exists.
6595
6596        Parameters
6597        ----------
6598        episode_name : str
6599            the episode name
6600
6601        Returns
6602        -------
6603        exists : bool
6604            `True` if the episode exists
6605
6606        """
6607        return self._episodes().check_name_validity(episode_name)
6608
6609    def search_exists(self, search_name: str) -> bool:
6610        """Check if a search already exists.
6611
6612        Parameters
6613        ----------
6614        search_name : str
6615            the search name
6616
6617        Returns
6618        -------
6619        exists : bool
6620            `True` if the search exists
6621
6622        """
6623        return self._searches().check_name_validity(search_name)
6624
6625    def prediction_exists(self, prediction_name: str) -> bool:
6626        """Check if a prediction already exists.
6627
6628        Parameters
6629        ----------
6630        prediction_name : str
6631            the prediction name
6632
6633        Returns
6634        -------
6635        exists : bool
6636            `True` if the prediction exists
6637
6638        """
6639        return self._predictions().check_name_validity(prediction_name)
6640
6641    @staticmethod
6642    def project_name_available(projects_path: str, project_name: str):
6643        """Check if a project name is available.
6644
6645        Parameters
6646        ----------
6647        projects_path : str
6648            the path to the projects directory
6649        project_name : str
6650            the name of the project to check
6651
6652        Returns
6653        -------
6654        available : bool
6655            `True` if the project name is available
6656
6657        """
6658        if projects_path is None:
6659            projects_path = os.path.join(str(Path.home()), "DLC2Action")
6660        return not os.path.exists(os.path.join(projects_path, project_name))
6661
6662    def _update_episode_metrics(self, episode_name: str, metrics: Dict):
6663        """Update meta data with evaluation results.
6664
6665        Parameters
6666        ----------
6667        episode_name : str
6668            the name of the episode
6669        metrics : dict
6670            the metrics dictionary to update with
6671
6672        """
6673        self._episodes().update_episode_metrics(episode_name, metrics)
6674
6675    def rename_episode(self, episode_name: str, new_episode_name: str):
6676        """Rename an episode.
6677
6678        Parameters
6679        ----------
6680        episode_name : str
6681            the current episode name
6682        new_episode_name : str
6683            the new episode name
6684
6685        """
6686        shutil.move(
6687            os.path.join(self.project_path, "results", "model", episode_name),
6688            os.path.join(self.project_path, "results", "model", new_episode_name),
6689        )
6690        shutil.move(
6691            os.path.join(self.project_path, "results", "logs", f"{episode_name}.txt"),
6692            os.path.join(
6693                self.project_path, "results", "logs", f"{new_episode_name}.txt"
6694            ),
6695        )
6696        self._episodes().rename_episode(episode_name, new_episode_name)
6697
6698
class _Runner:
    """A helper class for running hyperparameter searches.

    The runner stores the search configuration together with several callables
    taken from the parent `Project` and executes single trials through `run`.
    """

    # Metrics for which `run` reports 0 when no validation values were logged.
    _ZERO_FALLBACK_METRICS = (
        "f1",
        "precision",
        "recall",
        "accuracy",
        "count",
        "segmental_precision",
        "segmental_recall",
        "segmental_f1",
        "f_beta",
        "segmental_f_beta",
        "semisegmental_precision",
        "semisegmental_recall",
        "semisegmental_f1",
        "pr-auc",
        "semisegmental_pr-auc",
        "mAP",
    )
    # Metrics for which `run` reports infinity when no validation values were logged.
    _INF_FALLBACK_METRICS = ("loss", "mse", "mae", "edit_distance")

    def __init__(
        self,
        search_name: str,
        search_space: Dict,
        load_episode: str,
        load_epoch: int,
        metric: str,
        average: int,
        task: Union[TaskDispatcher, None],
        remove_saved_features: bool,
        project: Project,
    ):
        """Initialize the class.

        Parameters
        ----------
        search_name : str
            the name the search should be saved under
        search_space : dict
            a dictionary representing the search space; of this general structure:
            {'group/param_name': ('float/int/float_log/int_log', start, end),
            'group/param_name': ('categorical', [choices])}, e.g.
            {'data/overlap': ('int', 5, 100), 'training/lr': ('float_log', 1e-4, 1e-2),
            'data/feature_extraction': ('categorical', ['kinematic', 'bones'])};
            parameter names can be nested with further '/' separators
        load_episode : str
            the name of the episode to load the model from
        load_epoch : int
            the epoch to load the model from (if not provided, the last checkpoint is used)
        metric : str
            the metric to maximize/minimize (see direction)
        average : int
            the number of epochs to average the metric; if 0, the last value is taken
        task : TaskDispatcher
            the task dispatcher object (if `None`, one is created at the first trial)
        remove_saved_features : bool
            if `True`, the old datasets will be deleted when data parameters change
        project : Project
            the parent `Project` instance

        """
        self.search_space = search_space
        self.load_episode = load_episode
        self.load_epoch = load_epoch
        self.metric = metric
        self.average = average
        # feature dataset path of the most recent trial; None until the first trial
        self.feature_save_path = None
        # the spelling of this attribute is kept as is for backward compatibility
        self.remove_saved_featuress = remove_saved_features
        self.save_stores = project._save_stores
        self.remove_datasets = project.remove_saved_features
        self.task = task
        self.search_name = search_name
        self.update = project._update
        self.remove_episode = project.remove_episode
        self.fill = project._fill

    def clean(self):
        """Remove datasets if needed.

        This method removes the saved feature dataset of the most recent trial when
        the remove_saved_features flag is set. Nothing happens when no trial has
        set a feature path yet.

        """
        if self.remove_saved_featuress and self.feature_save_path is not None:
            self.remove_datasets([os.path.basename(self.feature_save_path)])

    @staticmethod
    def _suggest_value(trial, full_name: str, space):
        """Sample one search-space entry from the Optuna trial."""
        group, _, param_name = full_name.partition("/")
        kind = space[0]
        log = kind.endswith("log")
        if kind.startswith("int"):
            return trial.suggest_int(full_name, space[1], space[2], log=log)
        if kind.startswith("float"):
            return trial.suggest_float(full_name, space[1], space[2], log=log)
        if kind == "categorical":
            return trial.suggest_categorical(full_name, space[1])
        raise ValueError(
            "The search space has to be formatted as either "
            '("float"/"int"/"float_log"/"int_log", start, end) '
            f'or ("categorical", [choices]); got {space} for {group}/{param_name}'
        )

    @staticmethod
    def _set_nested(target: Dict, keys: List, value) -> None:
        """Store `value` in `target` under the chain of `keys`, creating dictionaries on the way."""
        node = target
        for key in keys[:-1]:
            node = node.setdefault(key, {})
        node[keys[-1]] = value

    def run(self, trial, parameters):
        """Make a trial run.

        The search-space values are sampled from the trial, merged into a copy of
        the base parameters and used to train the task.

        Parameters
        ----------
        trial : optuna.trial.Trial
            the Optuna trial object
        parameters : dict
            the base parameters dictionary

        Returns
        -------
        value : float
            the metric value for this trial; when the metric was not logged for the
            validation subset, 0 or infinity is returned for the metrics listed in
            the class constants and `None` for any other metric

        Raises
        ------
        ValueError
            if a search space entry is not formatted correctly

        """
        params = deepcopy(parameters)
        # plain nested dictionaries: {group: {key: {key: ... value}}}
        param_update = {}
        for full_name, space in self.search_space.items():
            group, _, param_name = full_name.partition("/")
            value = self._suggest_value(trial, full_name, space)
            keys = param_name.split("/")
            if len(keys) > 1:
                # numeric parts of a nested name are used as integer keys
                keys = [int(key) if key.isnumeric() else key for key in keys]
            self._set_nested(param_update.setdefault(group, {}), keys, value)
        params = self.update(params, param_update)
        # every trial reuses the same temporary episode name
        trial_episode = f"_{self.search_name}"
        self.remove_episode(trial_episode)
        params = self.fill(
            params,
            trial_episode,
            self.load_episode,
            load_epoch=self.load_epoch,
            only_load_model=True,
        )
        new_feature_path = params["data"]["feature_save_path"]
        if self.feature_save_path != new_feature_path:
            # the feature path changed: drop the previous dataset if requested
            if self.feature_save_path is not None:
                self.clean()
            self.feature_save_path = new_feature_path
        self.save_stores(params)
        if self.task is None:
            self.task = TaskDispatcher(deepcopy(params))
        else:
            self.task.update_task(params)

        _, metrics_log = self.task.train(trial, self.metric)
        if self.metric in metrics_log["val"]:
            metric_values = metrics_log["val"][self.metric]
            if self.average > 0:
                # NOTE(review): this averages the `average` largest logged values, not
                # the last `average` epochs; confirm that this is intended, in
                # particular for metrics that are minimized.
                return np.mean(sorted(metric_values)[-self.average :])
            return metric_values[-1]
        if self.metric in self._ZERO_FALLBACK_METRICS:
            return 0
        if self.metric in self._INF_FALLBACK_METRICS:
            return float("inf")
        warnings.warn(
            f"The {self.metric} metric was not logged for the validation subset "
            "and has no fallback value; the trial returns None"
        )
        return None
class Project:
  55class Project:
  56    """A class to create and maintain the project files + keep track of experiments."""
  57
  58    def __init__(
  59        self,
  60        name: str,
  61        data_type: str = None,
  62        annotation_type: str = "none",
  63        projects_path: str = None,
  64        data_path: Union[str, List] = None,
  65        annotation_path: Union[str, List] = None,
  66        copy: bool = False,
  67    ) -> None:
  68        """Initialize the class.
  69
  70        Parameters
  71        ----------
  72        name : str
  73            name of the project
  74        data_type : str, optional
  75            data type (run Project.data_types() to see available options; has to be provided if the project is being
  76            created)
  77        annotation_type : str, default 'none'
  78            annotation type (run Project.annotation_types() to see available options)
  79        projects_path : str, optional
  80            path to the projects folder (is filled with ~/DLC2Action by default)
  81        data_path : str, optional
  82            path to the folder containing input files for the project (has to be provided if the project is being
  83            created)
  84        annotation_path : str, optional
  85            path to the folder containing annotation files for the project
  86        copy : bool, default False
  87            if True, the files from annotation_path and data_path will be copied to the projects folder;
  88            otherwise they will be moved
  89
  90        """
  91        if projects_path is None:
  92            projects_path = os.path.join(str(Path.home()), "DLC2Action")
  93        if not os.path.exists(projects_path):
  94            os.mkdir(projects_path)
  95        self.project_path = os.path.join(projects_path, name)
  96        self.name = name
  97        self.data_type = data_type
  98        self.annotation_type = annotation_type
  99        self.data_path = data_path
 100        self.annotation_path = annotation_path
 101        if not os.path.exists(self.project_path):
 102            if data_type is None:
 103                raise ValueError(
 104                    "The data_type parameter is necessary when creating a new project!"
 105                )
 106            self._initialize_project(
 107                data_type, annotation_type, data_path, annotation_path, copy
 108            )
 109        else:
 110            self.annotation_type, self.data_type = self._read_types()
 111            if data_type != self.data_type and data_type is not None:
 112                raise ValueError(
 113                    f"The project has already been initialized with data_type={self.data_type}!"
 114                )
 115            if annotation_type != self.annotation_type and annotation_type != "none":
 116                raise ValueError(
 117                    f"The project has already been initialized with annotation_type={self.annotation_type}!"
 118                )
 119            self.annotation_path, data_path = self._read_paths()
 120            if self.data_path is None:
 121                self.data_path = data_path
 122            # if data_path != self.data_path and data_path is not None:
 123            #     raise ValueError(
 124            #         f"The project has already been initialized with data_path={self.data_path}!"
 125            #     )
 126            if annotation_path != self.annotation_path and annotation_path is not None:
 127                raise ValueError(
 128                    f"The project has already been initialized with annotation_path={self.annotation_path}!"
 129                )
 130        self._update_configs()
 131
    def _make_prediction(
        self,
        prediction_name: str,
        episode_names: List,
        load_epochs: Union[List[int], int] = None,
        parameters_update: Dict = None,
        data_path: str = None,
        file_paths: Set = None,
        mode: str = "all",
        augment_n: int = 0,
        evaluate: bool = False,
        task: TaskDispatcher = None,
        embedding: bool = False,
        annotation_type: str = "none",
    ) -> Tuple[TaskDispatcher, Dict, str, torch.Tensor, str, Dict]:
        """Generate a prediction.

        The runs of all requested episodes are collected, a prediction is made with
        each of them and the predictions are averaged.

        Parameters
        ----------
        prediction_name : str
            name of the prediction
        episode_names : List
            names of the episodes to use for the prediction
        load_epochs : Union[List[int], int], optional
            epochs to load for each run; if a single integer is provided, it will be used for all runs;
            if None, `None` is passed on as the epoch for every run
        parameters_update : Dict, optional
            dictionary with parameters to update the task parameters
        data_path : str, optional
            path to the data folder, passed on to `_make_task_prediction`
        file_paths : Set, optional
            set of file paths to use for the prediction, passed on to `_make_task_prediction`
        mode : str, default "all"
            mode of the prediction; can be "train", "val", "test" or "all"
        augment_n : int, default 0
            number of augmentations, passed on to `task.predict`
        evaluate : bool, default False
            if True, every prediction is evaluated; the metrics are saved to the episode meta data
            only when `mode` is "val"
        task : TaskDispatcher, optional
            task object to use for the prediction, passed on to `_make_task_prediction`
        embedding : bool, default False
            passed on to `task.predict` as the `embedding` flag
        annotation_type : str, default "none"
            type of the annotation, passed on to `_make_task_prediction`

        Returns
        -------
        task : TaskDispatcher
            task object used for the last run
        parameters : Dict
            parameters used for the last run
        mode : str
            data mode returned by `_make_task_prediction` for the last run
        prediction : torch.Tensor
            the average of the predictions over all runs; indexed with three axes, behaviors on the
            second one (presumably of shape (num_videos, num_behaviors, num_frames) — TODO confirm)
        inference_time : str
            total time spent in `task.predict` and reordering, in the format "H:MM:SS"
        behavior_dict : Dict
            behavior dictionary of the first run

        Raises
        ------
        ValueError
            if the runs do not share the same set of behaviors

        """

        # every episode can consist of several runs; all of them are used
        names = []
        for episode_name in episode_names:
            names += self._episodes().get_runs(episode_name)
        if len(names) == 0:
            warnings.warn(f"None of the episodes {episode_names} exist!")
            # NOTE(review): the code continues with a single `None` run name; what the
            # helpers called below do with it is not visible here.
            names = [None]
        if load_epochs is None:
            load_epochs = [None for _ in names]
        elif isinstance(load_epochs, int):
            load_epochs = [load_epochs for _ in names]
        assert len(load_epochs) == len(
            names
        ), f"Length of load_epochs ({len(load_epochs)}) must match the number of episodes ({len(names)})!"
        prediction = None
        # never reassigned in this method, so no thresholds are passed on
        decision_thresholds = None
        time_total = 0
        behavior_dicts = [
            self.get_behavior_dictionary(episode_name) for episode_name in names
        ]

        if not all(
            [
                set(d.values()) == set(behavior_dicts[0].values())
                for d in behavior_dicts[1:]
            ]
        ):
            raise ValueError(
                f"Episodes {episode_names} have different sets of behaviors!"
            )
        # the behavior order of the first run is the reference order
        behaviors = list(behavior_dicts[0].values())

        for episode_name, load_epoch, behavior_dict in zip(
            names, load_epochs, behavior_dicts
        ):
            print(f"episode {episode_name}")
            task, parameters, data_mode = self._make_task_prediction(
                prediction_name=prediction_name,
                load_episode=episode_name,
                parameters_update=parameters_update,
                load_epoch=load_epoch,
                data_path=data_path,
                mode=mode,
                file_paths=file_paths,
                task=task,
                decision_thresholds=decision_thresholds,
                annotation_type=annotation_type,
            )
            # data_mode = "train" if mode == "all" else mode
            time_start = time.time()
            new_pred = task.predict(
                data_mode,
                raw_output=True,
                apply_primary_function=True,
                augment_n=augment_n,
                embedding=embedding,
            )
            # indices[i] is the position of this run's behavior i in the reference order
            indices = [
                behaviors.index(behavior_dict[i]) for i in range(new_pred.shape[1])
            ]
            # NOTE(review): indexing with `indices` puts channel indices[j] at position j,
            # which matches the reference order only if the mapping is its own inverse
            # (identity or pairwise swaps); confirm whether the inverse permutation is
            # needed and how `evaluate_prediction` interprets `indices`.
            new_pred = new_pred[:, indices, :]
            time_end = time.time()
            time_total += time_end - time_start
            if evaluate:
                _, metrics = task.evaluate_prediction(
                    new_pred, data=data_mode, indices=indices
                )
                if mode == "val":
                    self._update_episode_metrics(episode_name, metrics)
            if prediction is None:
                prediction = new_pred
            else:
                prediction += new_pred
            print("\n")
        # format the accumulated seconds as H:MM:SS
        hours = int(time_total // 3600)
        time_total -= hours * 3600
        minutes = int(time_total // 60)
        time_total -= minutes * 60
        seconds = int(time_total)
        inference_time = f"{hours}:{minutes:02}:{seconds:02}"
        # turn the sum over runs into the mean
        prediction /= len(names)
        return (
            task,
            parameters,
            data_mode,
            prediction,
            inference_time,
            behavior_dicts[0],
        )
 279
 280    def _make_task_prediction(
 281        self,
 282        prediction_name: str,
 283        load_episode: str = None,
 284        parameters_update: Dict = None,
 285        load_epoch: int = None,
 286        data_path: str = None,
 287        annotation_path: str = None,
 288        mode: str = "val",
 289        file_paths: Set = None,
 290        decision_thresholds: List = None,
 291        task: TaskDispatcher = None,
 292        annotation_type: str = "none",
 293    ) -> Tuple[TaskDispatcher, Dict, str]:
 294        """Make a `TaskDispatcher` object that will be used to generate a prediction."""
 295        if parameters_update is None:
 296            parameters_update = {}
 297        parameters_update_second = {}
 298        if mode == "all" or data_path is not None or file_paths is not None:
 299            parameters_update_second["training"] = {
 300                "val_frac": 0,
 301                "test_frac": 0,
 302                "partition_method": "random",
 303                "save_split": False,
 304                "split_path": None,
 305            }
 306            mode = "train"
 307        if decision_thresholds is not None:
 308            if (
 309                len(decision_thresholds)
 310                == self._episode(load_episode).get_num_classes()
 311            ):
 312                parameters_update_second["general"] = {
 313                    "threshold_value": decision_thresholds
 314                }
 315            else:
 316                raise ValueError(
 317                    f"The length of the decision thresholds {decision_thresholds} "
 318                    f"must be equal to the length of the behaviors dictionary "
 319                    f"{self._episode(load_episode).get_behaviors_dict()}"
 320                )
 321        data_param_update = {}
 322        if data_path is not None:
 323            data_param_update = {"data_path": data_path}
 324            if annotation_path is None:
 325                data_param_update["annotation_path"] = data_path
 326        if annotation_path is not None:
 327            data_param_update["annotation_path"] = annotation_path
 328        if file_paths is not None:
 329            data_param_update = {"data_path": None, "file_paths": file_paths}
 330        parameters_update = self._update(parameters_update, {"data": data_param_update})
 331        if data_path is not None or file_paths is not None:
 332            general_update = {
 333                "annotation_type": annotation_type,
 334                "only_load_annotated": False,
 335            }
 336        else:
 337            general_update = {}
 338        parameters_update = self._update(parameters_update, {"general": general_update})
 339        task, parameters = self._make_task(
 340            episode_name=prediction_name,
 341            load_episode=load_episode,
 342            parameters_update=parameters_update,
 343            parameters_update_second=parameters_update_second,
 344            load_epoch=load_epoch,
 345            purpose="prediction",
 346            task=task,
 347        )
 348        return task, parameters, mode
 349
 350    def _make_task_training(
 351        self,
 352        episode_name: str,
 353        load_episode: str = None,
 354        parameters_update: Dict = None,
 355        load_epoch: int = None,
 356        load_search: str = None,
 357        load_parameters: list = None,
 358        round_to_binary: list = None,
 359        load_strict: bool = True,
 360        continuing: bool = False,
 361        task: TaskDispatcher = None,
 362        mask_name: str = None,
 363        throwaway: bool = False,
 364    ) -> Tuple[TaskDispatcher, Dict, str]:
 365        """Make a `TaskDispatcher` object that will be used to generate a prediction."""
 366        if parameters_update is None:
 367            parameters_update = {}
 368        if continuing:
 369            purpose = "continuing"
 370        else:
 371            purpose = "training"
 372        if mask_name is not None:
 373            mask_name = os.path.join(self._mask_path(), f"{mask_name}.pickle")
 374        parameters_update_second = {"data": {"real_lens": mask_name}}
 375        if throwaway:
 376            parameters_update = self._update(
 377                parameters_update, {"training": {"normalize": False, "device": "cpu"}}
 378            )
 379        return self._make_task(
 380            episode_name,
 381            load_episode,
 382            parameters_update,
 383            parameters_update_second,
 384            load_epoch,
 385            load_search,
 386            load_parameters,
 387            round_to_binary,
 388            purpose,
 389            task,
 390            load_strict=load_strict,
 391        )
 392
 393    def _make_parameters(
 394        self,
 395        episode_name: str,
 396        load_episode: str = None,
 397        parameters_update: Dict = None,
 398        parameters_update_second: Dict = None,
 399        load_epoch: int = None,
 400        load_search: str = None,
 401        load_parameters: list = None,
 402        round_to_binary: list = None,
 403        purpose: str = "train",
 404        load_strict: bool = True,
 405    ):
 406        """Construct a parameters dictionary."""
 407        if parameters_update is None:
 408            parameters_update = {}
 409        pars_update = deepcopy(parameters_update)
 410        if parameters_update_second is None:
 411            parameters_update_second = {}
 412        if (
 413            purpose == "prediction"
 414            and "model" in pars_update.keys()
 415            and pars_update["general"]["model_name"] != "motionbert"
 416        ):
 417            raise ValueError("Cannot change model parameters after training!")
 418        if purpose in ["continuing", "prediction"] and load_episode is not None:
 419            read_parameters = self._read_parameters()
 420            parameters = self._episodes().load_parameters(load_episode)
 421            parameters["metrics"] = self._update(
 422                read_parameters["metrics"], parameters["metrics"]
 423            )
 424            parameters["ssl"] = self._update(
 425                read_parameters["ssl"], parameters.get("ssl", {})
 426            )
 427        else:
 428            parameters = self._read_parameters()
 429        if "model" in pars_update:
 430            model_params = pars_update.pop("model")
 431        else:
 432            model_params = None
 433        if "features" in pars_update:
 434            feat_params = pars_update.pop("features")
 435        else:
 436            feat_params = None
 437        if "augmentations" in pars_update:
 438            aug_params = pars_update.pop("augmentations")
 439        else:
 440            aug_params = None
 441        parameters = self._update(parameters, pars_update)
 442        if pars_update.get("general", {}).get("model_name") is not None:
 443            model_name = parameters["general"]["model_name"]
 444            parameters["model"] = self._open_yaml(
 445                os.path.join(self.project_path, "config", "model", f"{model_name}.yaml")
 446            )
 447        if pars_update.get("general", {}).get("feature_extraction") is not None:
 448            feat_name = parameters["general"]["feature_extraction"]
 449            parameters["features"] = self._open_yaml(
 450                os.path.join(
 451                    self.project_path, "config", "features", f"{feat_name}.yaml"
 452                )
 453            )
 454            aug_name = options.extractor_to_transformer[
 455                parameters["general"]["feature_extraction"]
 456            ]
 457            parameters["augmentations"] = self._open_yaml(
 458                os.path.join(
 459                    self.project_path, "config", "augmentations", f"{aug_name}.yaml"
 460                )
 461            )
 462        if model_params is not None:
 463            parameters["model"] = self._update(parameters["model"], model_params)
 464        if feat_params is not None:
 465            parameters["features"] = self._update(parameters["features"], feat_params)
 466        if aug_params is not None:
 467            parameters["augmentations"] = self._update(
 468                parameters["augmentations"], aug_params
 469            )
 470        if load_search is not None:
 471            parameters = self._update_with_search(
 472                parameters, load_search, load_parameters, round_to_binary
 473            )
 474        parameters = self._fill(
 475            parameters,
 476            episode_name,
 477            load_episode,
 478            load_epoch=load_epoch,
 479            load_strict=load_strict,
 480            only_load_model=(purpose != "continuing"),
 481            continuing=(purpose in ["prediction", "continuing"]),
 482            enforce_split_parameters=(purpose == "prediction"),
 483        )
 484        parameters = self._update(parameters, parameters_update_second)
 485        return parameters
 486
 487    def _make_task(
 488        self,
 489        episode_name: str,
 490        load_episode: str = None,
 491        parameters_update: Dict = None,
 492        parameters_update_second: Dict = None,
 493        load_epoch: int = None,
 494        load_search: str = None,
 495        load_parameters: list = None,
 496        round_to_binary: list = None,
 497        purpose: str = "train",
 498        task: TaskDispatcher = None,
 499        load_strict: bool = True,
 500    ) -> Tuple[TaskDispatcher, Union[CommentedMap, dict]]:
 501        """Make a `TaskDispatcher` object.
 502
 503        The task parameters are read from the config files and then updated with the
 504        parameters_update dictionary. The model can be either initialized from scratch or loaded from one of the
 505        previous experiments. All parameters and results are saved in the meta files and can be accessed with the
 506        list_episodes() function. The train/test/validation split is saved and loaded from a file whenever the
 507        same split parameters are used. The pre-computed datasets are also saved and loaded whenever the same
 508        data parameters are used.
 509
 510        Parameters
 511        ----------
 512        episode_name : str
 513            the name of the episode
 514        load_episode : str, optional
 515            the (previously run) episode name to load the model from
 516        parameters_update : dict, optional
 517            the dictionary used to update the parameters from the config
 518        parameters_update_second : dict, optional
 519            the dictionary used to update the parameters after the automatic fill-out
 520        load_epoch : int, optional
 521            the epoch to load (if load_episodes is not None); if not provided, the last epoch is used
 522        load_search : str, optional
 523            the hyperparameter search result to load
 524        load_parameters : list, optional
 525            a list of string names of the parameters to load from load_search (if not provided, all parameters
 526            are loaded)
 527        round_to_binary : list, optional
 528            a list of string names of the loaded parameters that should be rounded to the nearest power of two
 529        purpose : {"train", "continuing", "prediction"}
 530            the purpose of the task object (`"train"` for training from scratch, `"continuing"` for continuing
 531            the training of an interrupted episode, `"prediction"` for generating a prediction)
 532        task : TaskDispatcher, optional
 533            a pre-existing task; if provided, the method will update the task instead of creating a new one
 534            (this might save time, mainly on dataset loading)
 535
 536        Returns
 537        -------
 538        task : TaskDispatcher
 539            the `TaskDispatcher` instance
 540        parameters : dict
 541            the parameters dictionary that describes the task
 542
 543        """
 544        parameters = self._make_parameters(
 545            episode_name,
 546            load_episode,
 547            parameters_update,
 548            parameters_update_second,
 549            load_epoch,
 550            load_search,
 551            load_parameters,
 552            round_to_binary,
 553            purpose,
 554            load_strict=load_strict,
 555        )
 556        if task is None:
 557            task = TaskDispatcher(parameters)
 558        else:
 559            task.update_task(parameters)
 560        self._save_stores(parameters)
 561        return task, parameters
 562
 563    def get_decision_thresholds(
 564        self,
 565        episode_names: List,
 566        metric_name: str = "f1",
 567        parameters_update: Dict = None,
 568        load_epochs: List = None,
 569        remove_saved_features: bool = False,
 570    ) -> Tuple[List, List, TaskDispatcher]:
 571        """Compute optimal decision thresholds or load them if they have been computed before.
 572
 573        Parameters
 574        ----------
 575        episode_names : List
 576            a list of episode names
 577        metric_name : {"f1", "segmental_f1", "semisegmental_f1", "f_beta", "segmental_f_beta"}
 578            the metric to optimize
 579        parameters_update : dict, optional
 580            the parameter update dictionary
 581        load_epochs : list, optional
 582            a list of epochs to load (by default last are loaded)
 583        remove_saved_features : bool, default False
 584            if `True`, the dataset will be deleted after the computation
 585
 586        Returns
 587        -------
 588        thresholds : list
 589            a list of float decision threshold values
 590        classes : list
 591            the label names corresponding to the values
 592        task : TaskDispatcher | None
 593            the task used in computation
 594
 595        """
 596        parameters = self._make_parameters(
 597            "_",
 598            episode_names[0],
 599            parameters_update,
 600            {},
 601            load_epochs[0],
 602            purpose="prediction",
 603        )
 604        thresholds = self._thresholds().find_thresholds(
 605            episode_names,
 606            load_epochs,
 607            metric_name,
 608            metric_parameters=parameters["metrics"][metric_name],
 609        )
 610        task = None
 611        behaviors = list(self._episode(episode_names[0]).get_behaviors_dict().values())
 612        return thresholds, behaviors, task
 613
 614    def run_episode(
 615        self,
 616        episode_name: str,
 617        load_episode: str = None,
 618        parameters_update: Dict = None,
 619        task: TaskDispatcher = None,
 620        load_epoch: int = None,
 621        load_search: str = None,
 622        load_parameters: list = None,
 623        round_to_binary: list = None,
 624        load_strict: bool = True,
 625        n_seeds: int = 1,
 626        force: bool = False,
 627        suppress_name_check: bool = False,
 628        remove_saved_features: bool = False,
 629        mask_name: str = None,
 630        autostop_metric: str = None,
 631        autostop_interval: int = 50,
 632        autostop_threshold: float = 0.001,
 633        loading_bar: bool = False,
 634        trial: Tuple = None,
 635    ) -> TaskDispatcher:
 636        """Run an episode.
 637
 638        The task parameters are read from the config files and then updated with the
 639        parameters_update dictionary. The model can be either initialized from scratch or loaded from one of the
 640        previous experiments. All parameters and results are saved in the meta files and can be accessed with the
 641        list_episodes() function. The train/test/validation split is saved and loaded from a file whenever the
 642        same split parameters are used. The pre-computed datasets are also saved and loaded whenever the same
 643        data parameters are used.
 644
 645        You can use the autostop parameters to finish training when the parameters are not improving. It will be
 646        stopped if the average value of `autostop_metric` over the last `autostop_interval` epochs is smaller than
 647        the average over the previous `autostop_interval` epochs + `autostop_threshold`. For example, if the
 648        current epoch is 120 and `autostop_interval` is 50, the averages over epochs 70-120 and 20-70 will be compared.
 649
 650        Parameters
 651        ----------
 652        episode_name : str
 653            the episode name
 654        load_episode : str, optional
 655            the (previously run) episode name to load the model from; if the episode has multiple runs,
 656            the new episode will have the same number of runs, each starting with one of the pre-trained models
 657        parameters_update : dict, optional
 658            the dictionary used to update the parameters from the config files
 659        task : TaskDispatcher, optional
 660            a pre-existing `TaskDispatcher` object (if provided, the method will update it instead of creating
 661            a new instance)
 662        load_epoch : int, optional
 663            the epoch to load (if load_episodes is not None); if not provided, the last epoch is used
 664        load_search : str, optional
 665            the hyperparameter search result to load
 666        load_parameters : list, optional
 667            a list of string names of the parameters to load from load_search (if not provided, all parameters
 668            are loaded)
 669        round_to_binary : list, optional
 670            a list of string names of the loaded parameters that should be rounded to the nearest power of two
 671        load_strict : bool, default True
 672            if `False`, matching weights will be loaded from `load_episode` and differences in parameter name lists and
 673            weight shapes will be ignored; otherwise mismatches will prompt a `RuntimeError`
 674        n_seeds : int, default 1
 675            the number of runs to perform; if `n_seeds > 1`, the episodes will be named `episode_name#run_index`, e.g.
 676            `test_episode#0` and `test_episode#1`
 677        force : bool, default False
 678            if `True` and an episode with name `episode_name` already exists, it will be overwritten (use with caution!)
 679        suppress_name_check : bool, default False
 680            if `True`, episode names with a double colon are allowed (please don't use this option unless you understand
 681            why they are usually forbidden)
 682        remove_saved_features : bool, default False
 683            if `True`, the dataset will be deleted after training
 684        mask_name : str, optional
 685            the name of the real_lens to apply
 686        autostop_metric : str, optional
 687            the autostop metric (can be any one of the tracked metrics of `'loss'`)
 688        autostop_interval : int, default 50
 689            the number of epochs to average the autostop metric over
 690        autostop_threshold : float, default 0.001
 691            the autostop difference threshold
 692        loading_bar : bool, default False
 693            if `True`, a loading bar will be displayed
 694        trial : tuple, optional
 695            a tuple of (trial, metric) for hyperparameter search
 696
 697        Returns
 698        -------
 699        TaskDispatcher
 700            the `TaskDispatcher` object
 701
 702        """
 703
 704        import gc
 705
 706        gc.collect()
 707        if torch.cuda.is_available():
 708            torch.cuda.empty_cache()
 709
 710        if type(n_seeds) is not int or n_seeds < 1:
 711            raise ValueError(
 712                f"The n_seeds parameter has to be an integer larger than 0; got {n_seeds}"
 713            )
 714        if n_seeds > 1 and mask_name is not None:
 715            raise ValueError("Cannot apply a real_lens with n_seeds > 1")
 716        self._check_episode_validity(
 717            episode_name, allow_doublecolon=suppress_name_check, force=force
 718        )
 719        load_runs = self._episodes().get_runs(load_episode)
 720        if len(load_runs) > 1:
 721            task = self.run_episodes(
 722                episode_names=[
 723                    f'{episode_name}#{run.split("#")[-1]}' for run in load_runs
 724                ],
 725                load_episodes=load_runs,
 726                parameters_updates=[parameters_update for _ in load_runs],
 727                load_epochs=[load_epoch for _ in load_runs],
 728                load_searches=[load_search for _ in load_runs],
 729                load_parameters=[load_parameters for _ in load_runs],
 730                round_to_binary=[round_to_binary for _ in load_runs],
 731                load_strict=[load_strict for _ in load_runs],
 732                suppress_name_check=True,
 733                force=force,
 734                remove_saved_features=False,
 735            )
 736            if remove_saved_features:
 737                self._remove_stores(
 738                    {
 739                        "general": task.general_parameters,
 740                        "data": task.data_parameters,
 741                        "features": task.feature_parameters,
 742                    }
 743                )
 744            if n_seeds > 1:
 745                warnings.warn(
 746                    f"The n_seeds parameter is disregarded since load_episode={load_episode} has multiple runs"
 747                )
 748        elif n_seeds > 1:
 749
 750            self.run_episodes(
 751                episode_names=[f"{episode_name}#{i}" for i in range(n_seeds)],
 752                load_episodes=[load_episode for _ in range(n_seeds)],
 753                parameters_updates=[parameters_update for _ in range(n_seeds)],
 754                load_epochs=[load_epoch for _ in range(n_seeds)],
 755                load_searches=[load_search for _ in range(n_seeds)],
 756                load_parameters=[load_parameters for _ in range(n_seeds)],
 757                round_to_binary=[round_to_binary for _ in range(n_seeds)],
 758                load_strict=[load_strict for _ in range(n_seeds)],
 759                suppress_name_check=True,
 760                force=force,
 761                remove_saved_features=remove_saved_features,
 762            )
 763        else:
 764            print(f"TRAINING {episode_name}")
 765            try:
 766                task, parameters = self._make_task_training(
 767                    episode_name,
 768                    load_episode,
 769                    parameters_update,
 770                    load_epoch,
 771                    load_search,
 772                    load_parameters,
 773                    round_to_binary,
 774                    continuing=False,
 775                    task=task,
 776                    mask_name=mask_name,
 777                    load_strict=load_strict,
 778                )
 779                self._save_episode(
 780                    episode_name,
 781                    parameters,
 782                    task.behaviors_dict(),
 783                    norm_stats=task.get_normalization_stats(),
 784                )
 785                time_start = time.time()
 786                if trial is not None:
 787                    trial, metric = trial
 788                else:
 789                    trial, metric = None, None
 790                logs = task.train(
 791                    autostop_metric=autostop_metric,
 792                    autostop_interval=autostop_interval,
 793                    autostop_threshold=autostop_threshold,
 794                    loading_bar=loading_bar,
 795                    trial=trial,
 796                    optimized_metric=metric,
 797                )
 798                time_end = time.time()
 799                time_total = time_end - time_start
 800                hours = int(time_total // 3600)
 801                time_total -= hours * 3600
 802                minutes = int(time_total // 60)
 803                time_total -= minutes * 60
 804                seconds = int(time_total)
 805                training_time = f"{hours}:{minutes:02}:{seconds:02}"
 806                self._update_episode_results(episode_name, logs, training_time)
 807                if remove_saved_features:
 808                    self._remove_stores(parameters)
 809                print("\n")
 810                return task
 811
 812            except Exception as e:
 813                if isinstance(e, optuna.exceptions.TrialPruned):
 814                    raise e
 815                else:
 816                    # if str(e) != f"The {episode_name} episode name is already in use!":
 817                    #     self.remove_episode(episode_name)
 818                    raise RuntimeError(f"Episode {episode_name} could not run")
 819
 820    def run_episodes(
 821        self,
 822        episode_names: List,
 823        load_episodes: List = None,
 824        parameters_updates: List = None,
 825        load_epochs: List = None,
 826        load_searches: List = None,
 827        load_parameters: List = None,
 828        round_to_binary: List = None,
 829        load_strict: List = None,
 830        force: bool = False,
 831        suppress_name_check: bool = False,
 832        remove_saved_features: bool = False,
 833    ) -> TaskDispatcher:
 834        """Run multiple episodes in sequence (and re-use previously loaded information).
 835
 836        For each episode, the task parameters are read from the config files and then updated with the
 837        parameter_update dictionary. The model can be either initialized from scratch or loaded from one of the
 838        previous experiments. All parameters and results are saved in the meta files and can be accessed with the
 839        list_episodes() function. The train/test/validation split is saved and loaded from a file whenever the
 840        same split parameters are used. The pre-computed datasets are also saved and loaded whenever the same
 841        data parameters are used.
 842
 843        Parameters
 844        ----------
 845        episode_names : list
 846            a list of strings of episode names
 847        load_episodes : list, optional
 848            a list of strings of (previously run) episode names to load the model from; if the episode has multiple runs,
 849            the new episode will have the same number of runs, each starting with one of the pre-trained models
 850        parameters_updates : list, optional
 851            a list of dictionaries used to update the parameters from the config
 852        load_epochs : list, optional
 853            a list of integers used to specify the epoch to load (if load_episodes is not None)
 854        load_searches : list, optional
 855            a list of strings of hyperparameter search results to load
 856        load_parameters : list, optional
 857            a list of lists of string names of the parameters to load from the searches
 858        round_to_binary : list, optional
 859            a list of string names of the loaded parameters that should be rounded to the nearest power of two
 860        load_strict : list, optional
 861            a list of boolean values specifying weight loading policy: if `False`, matching weights will be loaded from
 862            the corresponding episode and differences in parameter name lists and
 863            weight shapes will be ignored; otherwise mismatches will prompt a `RuntimeError` (by default `True` for
 864            every episode)
 865        force : bool, default False
 866            if `True` and an episode name is already taken, it will be overwritten (use with caution!)
 867        suppress_name_check : bool, default False
 868            if `True`, episode names with a double colon are allowed (please don't use this option unless you understand
 869            why they are usually forbidden)
 870        remove_saved_features : bool, default False
 871            if `True`, the dataset will be deleted after training
 872
 873        Returns
 874        -------
 875        TaskDispatcher
 876            the task dispatcher object
 877
 878        """
 879        task = None
 880        if load_searches is None:
 881            load_searches = [None for _ in episode_names]
 882        if load_episodes is None:
 883            load_episodes = [None for _ in episode_names]
 884        if parameters_updates is None:
 885            parameters_updates = [None for _ in episode_names]
 886        if load_parameters is None:
 887            load_parameters = [None for _ in episode_names]
 888        if load_epochs is None:
 889            load_epochs = [None for _ in episode_names]
 890        if load_strict is None:
 891            load_strict = [True for _ in episode_names]
 892        for (
 893            parameters_update,
 894            episode_name,
 895            load_episode,
 896            load_epoch,
 897            load_search,
 898            load_parameters_list,
 899            load_strict_value,
 900        ) in zip(
 901            parameters_updates,
 902            episode_names,
 903            load_episodes,
 904            load_epochs,
 905            load_searches,
 906            load_parameters,
 907            load_strict,
 908        ):
 909            task = self.run_episode(
 910                episode_name,
 911                load_episode,
 912                parameters_update,
 913                task,
 914                load_epoch,
 915                load_search,
 916                load_parameters_list,
 917                round_to_binary,
 918                load_strict_value,
 919                suppress_name_check=suppress_name_check,
 920                force=force,
 921                remove_saved_features=remove_saved_features,
 922            )
 923        return task
 924
 925    def continue_episode(
 926        self,
 927        episode_name: str,
 928        num_epochs: int = None,
 929        task: TaskDispatcher = None,
 930        n_seeds: int = 1,
 931        remove_saved_features: bool = False,
 932        device: str = "cuda",
 933        num_cpus: int = None,
 934    ) -> TaskDispatcher:
 935        """Load an older episode and continue running from the latest checkpoint.
 936
 937        All parameters as well as the model and optimizer state dictionaries are loaded from the episode.
 938
 939        Parameters
 940        ----------
 941        episode_name : str
 942            the name of the episode to continue
 943        num_epochs : int, optional
 944            the new number of epochs
 945        task : TaskDispatcher, optional
 946            a pre-existing task; if provided, the method will update the task instead of creating a new one
 947            (this might save time, mainly on dataset loading)
 948        n_seeds : int, default 1
 949            the number of runs to perform; if `n_seeds > 1`, the episodes will be named `episode_name#run_index`, e.g.
 950            `test_episode#0` and `test_episode#1`
 951        remove_saved_features : bool, default False
 952            if `True`, pre-computed features will be deleted after the run
 953        device : str, default "cuda"
 954            the torch device to use
 955        num_cpus : int, optional
 956            the number of CPUs to use for data loading; if `None`, the number of available CPUs will be used
 957
 958        Returns
 959        -------
 960        TaskDispatcher
 961            the task dispatcher
 962
 963        """
 964        runs = self._episodes().get_runs(episode_name)
 965        for run in runs:
 966            print(f"TRAINING {run}")
 967            if num_epochs is None and not self._episode(run).unfinished():
 968                continue
 969            parameters_update = {
 970                "training": {
 971                    "num_epochs": num_epochs,
 972                    "device": device,
 973                },
 974                "general": {"num_cpus": num_cpus},
 975            }
 976            task, parameters = self._make_task_training(
 977                run,
 978                load_episode=run,
 979                parameters_update=parameters_update,
 980                continuing=True,
 981                task=task,
 982            )
 983            time_start = time.time()
 984            logs = task.train()
 985            time_end = time.time()
 986            old_time = self._training_time(run)
 987            if not np.isnan(old_time):
 988                time_end += old_time
 989                time_total = time_end - time_start
 990                hours = int(time_total // 3600)
 991                time_total -= hours * 3600
 992                minutes = int(time_total // 60)
 993                time_total -= minutes * 60
 994                seconds = int(time_total)
 995                training_time = f"{hours}:{minutes:02}:{seconds:02}"
 996            else:
 997                training_time = np.nan
 998            self._save_episode(
 999                run,
1000                parameters,
1001                task.behaviors_dict(),
1002                suppress_validation=True,
1003                training_time=training_time,
1004                norm_stats=task.get_normalization_stats(),
1005            )
1006            self._update_episode_results(run, logs)
1007            print("\n")
1008        if len(runs) < n_seeds:
1009            for i in range(len(runs), n_seeds):
1010                self.run_episode(
1011                    f"{episode_name}#{i}",
1012                    parameters_update=self._episodes().load_parameters(runs[0]),
1013                    task=task,
1014                    suppress_name_check=True,
1015                )
1016        if remove_saved_features:
1017            self._remove_stores(parameters)
1018        return task
1019
1020    def run_default_hyperparameter_search(
1021        self,
1022        search_name: str,
1023        model_name: str,
1024        metric: str = "f1",
1025        best_n: int = 3,
1026        direction: str = "maximize",
1027        load_episode: str = None,
1028        load_epoch: int = None,
1029        load_strict: bool = True,
1030        prune: bool = True,
1031        force: bool = False,
1032        remove_saved_features: bool = False,
1033        overlap: float = 0,
1034        num_epochs: int = 50,
1035        test_frac: float = None,
1036        n_trials=150,
1037        batch_size=32,
1038    ):
1039        """Run an optuna hyperparameter search with default parameters for a model.
1040
1041        For the vast majority of cases, optimizing the default parameters should be enough.
1042        Check out `dlc2action.options.model_hyperparameters` for the lists of parameters.
1043        There are also options to set overlap, test fraction and number of epochs parameters for the search without
1044        modifying the project config files. However, if you want something more complex, look into
1045        `Project.run_hyperparameter_search`.
1046
1047        The task parameters are read from the config files and updated with the parameters_update dictionary.
1048        The model can be either initialized from scratch or loaded from a previously run episode.
1049        For each trial, the objective metric is averaged over a few best epochs.
1050
1051        Parameters
1052        ----------
1053        search_name : str
1054            the name of the search to store it in the meta files and load in run_episode
1055        model_name : str
1056            the name
1057        metric : str
1058            the metric to maximize/minimize (see direction); if the metric has an `"average"` parameter and it is set to
1059            `"none"` in the config files, it will be reset to `"macro"` for the search
1060        best_n : int, default 1
1061            the number of epochs to average the metric; if 0, the last value is taken
1062        direction : {'maximize', 'minimize'}
1063            optimization direction
1064        load_episode : str, optional
1065            the name of the episode to load the model from
1066        load_epoch : int, optional
1067            the epoch to load the model from (if not provided, the last checkpoint is used)
1068        load_strict : bool, default True
1069            if `True`, the model will be loaded only if the parameters match exactly
1070        prune : bool, default False
1071            if `True`, experiments where the optimized metric is improving too slowly will be terminated
1072            (with optuna HyperBand pruner)
1073        force : bool, default False
1074            if `True`, existing searches with the same name will be overwritten
1075        remove_saved_features : bool, default False
1076            if `True`, pre-computed features will be deleted after each run (if the data parameters change)
1077        overlap : float, default 0
1078            the overlap to use for the search
1079        num_epochs : int, default 50
1080            the number of epochs to use for the search
1081        test_frac : float, optional
1082            the test fraction to use for the search
1083        n_trials : int, default 150
1084            the number of trials to run
1085        batch_size : int, default 32
1086            the batch size to use for the search
1087
1088        Returns
1089        -------
1090        best_parameters : dict
1091            a dictionary of best parameters
1092
1093        """
1094        if model_name not in options.model_hyperparameters:
1095            raise ValueError(
1096                f"There is no default search space for {model_name}! Please choose from {options.model_hyperparameters.keys()} or try project.run_hyperparameter_search()"
1097            )
1098        pars = {
1099            "general": {"overlap": overlap, "model_name": model_name},
1100            "training": {"num_epochs": num_epochs, "batch_size": batch_size},
1101        }
1102        if test_frac is not None:
1103            pars["training"]["test_frac"] = test_frac
1104        if not metric.split("_")[-1].isnumeric():
1105            project_pars = self._read_parameters()
1106            if project_pars["metrics"][metric].get("average") == "none":
1107                pars["metrics"] = {metric: {"average": "macro"}}
1108        return self.run_hyperparameter_search(
1109            search_name=search_name,
1110            search_space=options.model_hyperparameters[model_name],
1111            metric=metric,
1112            n_trials=n_trials,
1113            best_n=best_n,
1114            parameters_update=pars,
1115            direction=direction,
1116            load_episode=load_episode,
1117            load_epoch=load_epoch,
1118            load_strict=load_strict,
1119            prune=prune,
1120            force=force,
1121            remove_saved_features=remove_saved_features,
1122        )
1123
1124    def run_hyperparameter_search(
1125        self,
1126        search_name: str,
1127        search_space: Dict,
1128        metric: str = "f1",
1129        n_trials: int = 20,
1130        best_n: int = 1,
1131        parameters_update: Dict = None,
1132        direction: str = "maximize",
1133        load_episode: str = None,
1134        load_epoch: int = None,
1135        load_strict: bool = True,
1136        prune: bool = False,
1137        force: bool = False,
1138        remove_saved_features: bool = False,
1139        make_plots: bool = True,
1140    ) -> Dict:
1141        """Run an optuna hyperparameter search.
1142
1143        For a simpler function that fits most use cases, check out `Project.run_default_hyperparameter_search()`.
1144
1145        To use a default search space with this method, import `dlc2action.options.model_hyperparameters`. It is
1146        a dictionary where keys are model names and values are default search spaces.
1147
1148        The task parameters are read from the config files and updated with the parameters_update dictionary.
1149        The model can be either initialized from scratch or loaded from a previously run episode.
1150        For each trial, the objective metric is averaged over a few best epochs.
1151
1152        Parameters
1153        ----------
1154        search_name : str
1155            the name of the search to store it in the meta files and load in run_episode
1156        search_space : dict
1157            a dictionary representing the search space; of this general structure:
1158            {'group/param_name': ('float/int/float_log/int_log', start, end),
1159            'group/param_name': ('categorical', [choices])}, e.g.
1160            {'data/overlap': ('int', 5, 100), 'training/lr': ('float_log', 1e-4, 1e-2),
1161            'data/feature_extraction': ('categorical', ['kinematic', 'bones'])};
1162        metric : str, default f1
1163            the metric to maximize/minimize (see direction)
1164        n_trials : int, default 20
1165            the number of optimization trials to run
1166        best_n : int, default 1
1167            the number of epochs to average the metric; if 0, the last value is taken
1168        parameters_update : dict, optional
1169            the parameters update dictionary
1170        direction : {'maximize', 'minimize'}
1171            optimization direction
1172        load_episode : str, optional
1173            the name of the episode to load the model from
1174        load_epoch : int, optional
1175            the epoch to load the model from (if not provided, the last checkpoint is used)
1176        load_strict : bool, default True
1177            if `True`, the model will be loaded only if the parameters match exactly
1178        prune : bool, default False
1179            if `True`, experiments where the optimized metric is improving too slowly will be terminated
1180            (with optuna HyperBand pruner)
1181        force : bool, default False
1182            if `True`, existing searches with the same name will be overwritten
1183        remove_saved_features : bool, default False
1184            if `True`, pre-computed features will be deleted after each run (if the data parameters change)
1185
1186        Returns
1187        -------
1188        dict
1189            a dictionary of best parameters
1190
1191        """
1192        self._check_search_validity(search_name, force=force)
1193        print(f"SEARCH {search_name}")
1194        self.remove_episode(f"_{search_name}")
1195        if parameters_update is None:
1196            parameters_update = {}
1197        parameters_update = self._update(
1198            parameters_update, {"general": {"metric_functions": {metric}}}
1199        )
1200        parameters = self._make_parameters(
1201            f"_{search_name}",
1202            load_episode,
1203            parameters_update,
1204            parameters_update_second={"training": {"model_save_path": None}},
1205            load_epoch=load_epoch,
1206            load_strict=load_strict,
1207        )
1208        task = None
1209
1210        if prune:
1211            pruner = optuna.pruners.HyperbandPruner()
1212        else:
1213            pruner = optuna.pruners.NopPruner()
1214        study = optuna.create_study(direction=direction, pruner=pruner)
1215        runner = _Runner(
1216            search_space=search_space,
1217            load_episode=load_episode,
1218            load_epoch=load_epoch,
1219            metric=metric,
1220            average=best_n,
1221            task=task,
1222            remove_saved_features=remove_saved_features,
1223            project=self,
1224            search_name=search_name,
1225        )
1226        study.optimize(lambda trial: runner.run(trial, parameters), n_trials=n_trials)
1227        if make_plots:
1228            search_path = self._search_path(search_name)
1229            os.mkdir(search_path)
1230            fig = optuna.visualization.plot_contour(study)
1231            plotly.offline.plot(
1232                fig, filename=os.path.join(search_path, f"{search_name}_contour.html")
1233            )
1234            fig = optuna.visualization.plot_param_importances(study)
1235            plotly.offline.plot(
1236                fig,
1237                filename=os.path.join(search_path, f"{search_name}_importances.html"),
1238            )
1239        best_params = study.best_params
1240        best_value = study.best_value
1241        if best_value == 0 or best_value == float("inf"):
1242            raise ValueError(
1243                f"Best metric value is {best_value}, check your partition method and make sure that all behaviors are present in the validation set!"
1244            )
1245        self._save_search(
1246            search_name,
1247            parameters,
1248            n_trials,
1249            best_params,
1250            best_value,
1251            metric,
1252            search_space,
1253        )
1254        self.remove_episode(f"_{search_name}")
1255        runner.clean()
1256        print(f"best parameters: {best_params}")
1257        print("\n")
1258        return best_params
1259
1260    def run_prediction(
1261        self,
1262        prediction_name: str,
1263        episode_names: List,
1264        load_epochs: List = None,
1265        parameters_update: Dict = None,
1266        augment_n: int = 10,
1267        data_path: str = None,
1268        mode: str = "all",
1269        file_paths: Set = None,
1270        remove_saved_features: bool = False,
1271        frame_number_map_file: str = None,
1272        force: bool = False,
1273        embedding: bool = False,
1274    ) -> None:
1275        """Load models from previously run episodes to generate a prediction.
1276
1277        The probabilities predicted by the models are averaged.
1278        Unless `submission` is `True`, the prediction results are saved as a pickled dictionary in the project_name/results/predictions folder
1279        under the {episode_name}_{load_epoch}.pickle name. The file is a nested dictionary where the first-level
1280        keys are the video ids, the second-level keys are the clip ids (like individual names) and the values
1281        are the prediction arrays.
1282
1283        Parameters
1284        ----------
1285        prediction_name : str
1286            the name of the prediction
1287        episode_names : list
1288            a list of string episode names to load the models from
1289        load_epochs : list or int, optional
1290            a list of integer epoch indices to load the model from; if None, the last ones are used, if int the same epoch is used for all episodes
1291        parameters_update : dict, optional
1292            a dictionary of parameter updates
1293        augment_n : int, default 10
1294            the number of augmentations to average over
1295        data_path : str, optional
1296            the data path to run the prediction for
1297        mode : {'all', 'test', 'val', 'train'}
1298            the subset of the data to make the prediction for (forced to 'all' if data_path is not None)
1299        file_paths : set, optional
1300            a set of string file paths (data with all prefixes + feature files, in any order) to run the prediction
1301            for
1302        remove_saved_features : bool, default False
1303            if `True`, pre-computed features will be deleted
1304        submission : bool, default False
1305            if `True`, a MABe-22 style submission file is generated
1306        frame_number_map_file : str, optional
1307            path to the frame number map file
1308        force : bool, default False
1309            if `True`, existing prediction with this name will be overwritten
1310        embedding : bool, default False
1311            if `True`, the prediction is made for the embedding task
1312
1313        """
1314        self._check_prediction_validity(prediction_name, force=force)
1315        print(f"PREDICTION {prediction_name}")
1316        task, parameters, mode, prediction, inference_time, behavior_dict = (
1317            self._make_prediction(
1318                prediction_name,
1319                episode_names,
1320                load_epochs,
1321                parameters_update,
1322                data_path,
1323                file_paths,
1324                mode,
1325                augment_n,
1326                evaluate=False,
1327                embedding=embedding,
1328            )
1329        )
1330        predicted = task.dataset(mode).generate_full_length_prediction(prediction)
1331
1332        if remove_saved_features:
1333            self._remove_stores(parameters)
1334
1335        self._save_prediction(
1336            prediction_name,
1337            predicted,
1338            parameters,
1339            task,
1340            mode,
1341            embedding,
1342            inference_time,
1343            behavior_dict,
1344        )
1345        print("\n")
1346
1347    def evaluate_prediction(
1348        self,
1349        prediction_name: str,
1350        parameters_update: Dict = None,
1351        data_path: str = None,
1352        annotation_path: str = None,
1353        file_paths: Set = None,
1354        mode: str = None,
1355        remove_saved_features: bool = False,
1356        annotation_type: str = "none",
1357        num_classes: int = None,  # Set when using data_path
1358    ) -> Tuple[float, dict]:
1359        """Make predictions and evaluate them
1360        inputs:
1361            prediction_name (str): the name of the prediction
1362            parameters_update (dict): a dictionary of parameter updates
1363            data_path (str): the data path to run the prediction for
1364            annotation_path (str): the annotation path to run the prediction for
1365            file_paths (set): a set of string file paths (data with all prefixes + feature files, in any order) to run the prediction for
1366            mode (str): the subset of the data to make the prediction for (forced to 'all' if data_path is not None)
1367            remove_saved_features (bool): if `True`, pre-computed features will be deleted
1368            annotation_type (str): the type of annotation to use for evaluation
1369            num_classes (int): the number of classes in the dataset, must be set with data_path
1370        outputs:
1371            results (dict): a dictionary of average values of metric functions
1372        """
1373
1374        prediction_path = os.path.join(
1375            self.project_path, "results", "predictions", f"{prediction_name}"
1376        )
1377        prediction_dict = {}
1378        for prediction_file_path in [
1379            os.path.join(prediction_path, i) for i in os.listdir(prediction_path)
1380        ]:
1381            with open(os.path.join(prediction_file_path), "rb") as f:
1382                prediction = pickle.load(f)
1383            video_id = os.path.basename(prediction_file_path).split(
1384                "_" + prediction_name
1385            )[0]
1386            prediction_dict[video_id] = prediction
1387        if parameters_update is None:
1388            parameters_update = {}
1389        parameters_update = self._update(
1390            self._predictions().load_parameters(prediction_name), parameters_update
1391        )
1392        parameters_update.pop("model")
1393        if not data_path is None:
1394            assert (
1395                not num_classes is None
1396            ), "num_classes must be provided if data_path is provided"
1397            parameters_update["general"]["num_classes"] = num_classes + int(
1398                parameters_update["general"]["exclusive"]
1399            )
1400        task, parameters, mode = self._make_task_prediction(
1401            "_",
1402            load_episode=None,
1403            parameters_update=parameters_update,
1404            data_path=data_path,
1405            annotation_path=annotation_path,
1406            file_paths=file_paths,
1407            mode=mode,
1408            annotation_type=annotation_type,
1409        )
1410        results = task.evaluate_prediction(prediction_dict, data=mode)
1411        if remove_saved_features:
1412            self._remove_stores(parameters)
1413        results = Project._reformat_results(
1414            results[1],
1415            task.behaviors_dict(),
1416            exclusive=task.general_parameters["exclusive"],
1417        )
1418        return results
1419
1420    def evaluate(
1421        self,
1422        episode_names: List,
1423        load_epochs: List = None,
1424        augment_n: int = 0,
1425        data_path: str = None,
1426        file_paths: Set = None,
1427        mode: str = None,
1428        parameters_update: Dict = None,
1429        multiple_episode_policy: str = "average",
1430        remove_saved_features: bool = False,
1431        skip_updating_meta: bool = True,
1432        annotation_type: str = "none",
1433    ) -> Dict:
1434        """Load one or several models from previously run episodes to make an evaluation.
1435
1436        By default it will run on the test (or validation, if there is no test) subset of the project dataset.
1437
1438        Parameters
1439        ----------
1440        episode_names : list
1441            a list of string episode names to load the models from
1442        load_epochs : list, optional
1443            a list of integer epoch indices to load the model from; if None, the last ones are used
1444        augment_n : int, default 0
1445            the number of augmentations to average over
1446        data_path : str, optional
1447            the data path to run the prediction for
1448        file_paths : set, optional
1449            a set of files to run the prediction for
1450        mode : {'test', 'val', 'train', 'all'}
1451            the subset of the data to make the prediction for (forced to 'all' if data_path is not None;
1452            by default 'test' if test subset is not empty and 'val' otherwise)
1453        parameters_update : dict, optional
1454            a dictionary with parameter updates (cannot change model parameters)
1455        multiple_episode_policy : {'average', 'statistics'}
1456            the policy to use when multiple episodes are provided
1457        remove_saved_features : bool, default False
1458            if `True`, the dataset will be deleted
1459        skip_updating_meta : bool, default True
1460            if `True`, the meta file will not be updated with the computed metrics
1461
1462        Returns
1463        -------
1464        metric : dict
1465            a dictionary of average values of metric functions
1466
1467        """
1468        names = []
1469        for episode_name in episode_names:
1470            names += self._episodes().get_runs(episode_name)
1471        if len(set(episode_names)) == 1:
1472            print(f"EVALUATION {episode_names[0]}")
1473        else:
1474            print(f"EVALUATION {episode_names}")
1475        if len(names) > 1:
1476            evaluate = True
1477        else:
1478            evaluate = False
1479        if multiple_episode_policy == "average":
1480            task, parameters, mode, prediction, inference_time, behavior_dict = (
1481                self._make_prediction(
1482                    "_",
1483                    episode_names,
1484                    load_epochs,
1485                    parameters_update,
1486                    mode=mode,
1487                    data_path=data_path,
1488                    file_paths=file_paths,
1489                    augment_n=augment_n,
1490                    evaluate=evaluate,
1491                    annotation_type=annotation_type,
1492                )
1493            )
1494            print("EVALUATE PREDICTION:")
1495            indices = [
1496                list(behavior_dict.keys()).index(i) for i in range(len(behavior_dict))
1497            ]
1498            _, results = task.evaluate_prediction(
1499                prediction, data=mode, indices=indices
1500            )
1501            if len(names) == 1 and mode == "val" and not skip_updating_meta:
1502                self._update_episode_metrics(names[0], results)
1503            results = Project._reformat_results(
1504                results,
1505                behavior_dict,
1506                exclusive=task.general_parameters["exclusive"],
1507            )
1508
1509        elif multiple_episode_policy == "statistics":
1510            values = defaultdict(lambda: [])
1511            task = None
1512            for name in names:
1513                (
1514                    task,
1515                    parameters,
1516                    mode,
1517                    prediction,
1518                    inference_time,
1519                    behavior_dict,
1520                ) = self._make_prediction(
1521                    "_",
1522                    [name],
1523                    load_epochs,
1524                    parameters_update,
1525                    mode=mode,
1526                    data_path=data_path,
1527                    file_paths=file_paths,
1528                    augment_n=augment_n,
1529                    evaluate=evaluate,
1530                    task=task,
1531                )
1532                _, metrics = task.evaluate_prediction(
1533                    prediction, data=mode, indices=list(behavior_dict.keys())
1534                )
1535                for name, value in metrics.items():
1536                    values[name].append(value)
1537                if mode == "val" and not skip_updating_meta:
1538                    self._update_episode_metrics(name, metrics)
1539            results = defaultdict(lambda: {})
1540            mean_string = ""
1541            std_string = ""
1542            for key, value_list in values.items():
1543                results[key]["mean"] = np.mean(value_list)
1544                results[key]["std"] = np.std(value_list)
1545                results[key]["all"] = value_list
1546                mean_string += f"{key} {np.mean(value_list):.3f}, "
1547                std_string += f"{key} {np.std(value_list):.3f}, "
1548            print("MEAN:")
1549            print(mean_string)
1550            print("STD:")
1551            print(std_string)
1552        else:
1553            raise ValueError(
1554                f"The {multiple_episode_policy} multiple episode policy is not recognized; please choose "
1555                f"from ['average', 'statistics']"
1556            )
1557        if len(names) > 0 and remove_saved_features:
1558            self._remove_stores(parameters)
1559        print(f"Inference time: {inference_time}")
1560        print("\n")
1561        return results
1562
1563    def run_suggestion(
1564        self,
1565        suggestions_name: str,
1566        error_episode: str = None,
1567        error_load_epoch: int = None,
1568        error_class: str = None,
1569        suggestions_prediction: str = None,
1570        suggestion_episodes: List = [None],
1571        suggestion_load_epoch: int = None,
1572        suggestion_classes: List = None,
1573        error_threshold: float = 0.5,
1574        error_threshold_diff: float = 0.1,
1575        error_hysteresis: bool = False,
1576        suggestion_threshold: Union[float, List] = 0.5,
1577        suggestion_threshold_diff: Union[float, List] = 0.1,
1578        suggestion_hysteresis: Union[bool, List] = True,
1579        min_frames_suggestion: int = 10,
1580        min_frames_al: int = 30,
1581        visibility_min_score: float = 0,
1582        visibility_min_frac: float = 0.7,
1583        augment_n: int = 0,
1584        exclude_classes: List = None,
1585        exclude_threshold: Union[float, List] = 0.6,
1586        exclude_threshold_diff: Union[float, List] = 0.1,
1587        exclude_hysteresis: Union[bool, List] = False,
1588        include_classes: List = None,
1589        include_threshold: Union[float, List] = 0.4,
1590        include_threshold_diff: Union[float, List] = 0.1,
1591        include_hysteresis: Union[bool, List] = False,
1592        data_path: str = None,
1593        file_paths: Set = None,
1594        parameters_update: Dict = None,
1595        mode: str = "all",
1596        force: bool = False,
1597        remove_saved_features: bool = False,
1598        cut_annotated: bool = False,
1599        background_threshold: float = None,
1600    ) -> None:
1601        """Create active learning and suggestion files.
1602
1603        Generate predictions with the error and suggestion model and use them to create
1604        suggestion files for the labeling interface. Those files will render as suggested labels
1605        at intervals with high pose estimation quality. Quality here is defined by probability of error
1606        (predicted by the error model) and visibility parameters.
1607
1608        If `error_episode` or `exclude_classes` is not `None`,
1609        an active learning file will be created as well (with frames with high predicted probability of classes
1610        from `exclude_classes` and/or errors excluded from the active learning intervals).
1611
1612        In all three steps (predicting errors, suggesting labels and excluding them from active learning intervals)
1613        you can apply one of three methods.
1614
1615        - **Simple threshold**
1616
1617            Set the `hysteresis` parameter (e.g. `error_hysteresis`) to `False` and the `threshold`
1618            parameter to $\alpha$.
1619            In this case if the probability of a label is predicted to be higher than $\alpha$ the frame will
1620            be considered labeled.
1621
1622        - **Hysteresis threshold**
1623
1624            Set the `hysteresis` parameter (e.g. `error_hysteresis`) to `True`, the `threshold`
1625            parameter to $\alpha$ and the `threshold_diff` parameter to $\beta$.
1626            Now intervals will be marked with a label if the probability of that label for all frames is higher
1627            than $\alpha - \beta$ and at least for one frame in that interval it is higher than $\alpha$.
1628
1629        - **Max hysteresis threshold**
1630
1631            Set the `hysteresis` parameter (e.g. `error_hysteresis`) to `True`, the `threshold`
1632            parameter to $\alpha$ and the `threshold_diff` parameter to `None`.
1633            With this combination intervals are marked with a label if that label is more likely than any other
1634            for all frames in this interval and at for at least one of those frames its probability is higher than
1635            $\alpha$.
1636
1637        Parameters
1638        ----------
1639        suggestions_name : str
1640            the name of the suggestions
1641        error_episode : str, optional
1642            the name of the episode where the error model should be loaded from
1643        error_load_epoch : int, optional
1644            the epoch the error model should be loaded from
1645        error_class : str, optional
1646            the name of the error class (in `error_episode`)
1647        suggestions_prediction : str, optional
1648            the name of the predictions that should be used for the suggestion model
1649        suggestion_episodes : list, optional
1650            the names of the episodes where the suggestion models should be loaded from
1651        suggestion_load_epoch : int, optional
1652            the epoch the suggestion model should be loaded from
1653        suggestion_classes : list, optional
1654            a list of string names of the classes that should be suggested (in `suggestion_episode`)
1655        error_threshold : float, default 0.5
1656            the hard threshold for error prediction
1657        error_threshold_diff : float, default 0.1
1658            the difference between soft and hard thresholds for error prediction (in case hysteresis is used)
1659        error_hysteresis : bool, default False
1660            if True, hysteresis is used for error prediction
1661        suggestion_threshold : float | list, default 0.5
1662            the hard threshold for class prediction (use a list to set different rules for different classes)
1663        suggestion_threshold_diff : float | list, default 0.1
1664            the difference between soft and hard thresholds for class prediction (in case hysteresis is used;
1665            use a list to set different rules for different classes)
1666        suggestion_hysteresis : bool | list, default True
1667            if True, hysteresis is used for class prediction (use a list to set different rules for different classes)
1668        min_frames_suggestion : int, default 10
1669            only actions longer than this number of frames will be suggested
1670        min_frames_al : int, default 30
1671            only active learning intervals longer than this number of frames will be suggested
1672        visibility_min_score : float, default 0
1673            the minimum visibility score for visibility filtering
1674        visibility_min_frac : float, default 0.7
1675            the minimum fraction of visible frames for visibility filtering
1676        augment_n : int, default 10
1677            the number of augmentations to average the predictions over
1678        exclude_classes : list, optional
1679            a list of string names of classes that should be excluded from the active learning intervals
1680        exclude_threshold : float | list, default 0.6
1681            the hard threshold for excluded class prediction (use a list to set different rules for different classes)
1682        exclude_threshold_diff : float | list, default 0.1
1683            the difference between soft and hard thresholds for excluded class prediction (in case hysteresis is used)
1684        exclude_hysteresis : bool | list, default False
1685            if True, hysteresis is used for excluded class prediction (use a list to set different rules for different classes)
1686        include_classes : list, optional
1687            a list of string names of classes that should be included into the active learning intervals
1688        include_threshold : float | list, default 0.6
1689            the hard threshold for included class prediction (use a list to set different rules for different classes)
1690        include_threshold_diff : float | list, default 0.1
1691            the difference between soft and hard thresholds for included class prediction (in case hysteresis is used)
1692        include_hysteresis : bool | list, default False
1693            if True, hysteresis is used for included class prediction (use a list to set different rules for different classes)
1694        data_path : str, optional
1695            the data path to run the prediction for
1696        file_paths : set, optional
1697            a set of string file paths (data with all prefixes + feature files, in any order) to run the prediction
1698            for
1699        parameters_update : dict, optional
1700            the parameters update dictionary
1701        mode : {'all', 'test', 'val', 'train'}
1702            the subset of the data to make the prediction for (forced to 'all' if data_path is not None)
1703        force : bool, default False
1704            if `True` and an episode with name `episode_name` already exists, it will be overwritten (use with caution!)
1705        remove_saved_features : bool, default False
1706            if `True`, the dataset will be deleted.
1707        cut_annotated : bool, default False
1708            if `True`, annotated frames will be cut from the suggestions
1709        background_threshold : float, default 0.5
1710            the threshold for background prediction
1711
1712        """
1713        self._check_suggestions_validity(suggestions_name, force=force)
1714        if any([x is None for x in suggestion_episodes]):
1715            suggestion_episodes = None
1716        if error_episode is None and (
1717            suggestion_episodes is None and suggestions_prediction is None
1718        ):
1719            raise ValueError(
1720                "Both error_episode and suggestion_episode parameters cannot be None at the same time"
1721            )
1722        print(f"SUGGESTION {suggestions_name}")
1723        task = None
1724        if suggestion_classes is None:
1725            suggestion_classes = []
1726        if exclude_classes is None:
1727            exclude_classes = []
1728        if include_classes is None:
1729            include_classes = []
1730        if isinstance(suggestion_threshold, list):
1731            if len(suggestion_threshold) != len(suggestion_classes):
1732                raise ValueError(
1733                    "The suggestion_threshold parameter has to be either a float value or a list of "
1734                    f"float values of the same length as suggestion_classes (got a list of length "
1735                    f"{len(suggestion_threshold)} for {len(suggestion_classes)} classes)"
1736                )
1737        else:
1738            suggestion_threshold = [suggestion_threshold for _ in suggestion_classes]
1739        if isinstance(suggestion_threshold_diff, list):
1740            if len(suggestion_threshold_diff) != len(suggestion_classes):
1741                raise ValueError(
1742                    "The suggestion_threshold_diff parameter has to be either a float value or a list of "
1743                    f"float values of the same length as suggestion_classes (got a list of length "
1744                    f"{len(suggestion_threshold)} for {len(suggestion_classes)} classes)"
1745                )
1746        else:
1747            suggestion_threshold_diff = [
1748                suggestion_threshold_diff for _ in suggestion_classes
1749            ]
1750        if isinstance(suggestion_hysteresis, list):
1751            if len(suggestion_hysteresis) != len(suggestion_classes):
1752                raise ValueError(
1753                    "The suggestion_threshold_diff parameter has to be either a float value or a list of "
1754                    f"float values of the same length as suggestion_classes (got a list of length "
1755                    f"{len(suggestion_hysteresis)} for {len(suggestion_classes)} classes)"
1756                )
1757        else:
1758            suggestion_hysteresis = [suggestion_hysteresis for _ in suggestion_classes]
1759        if isinstance(exclude_threshold, list):
1760            if len(exclude_threshold) != len(exclude_classes):
1761                raise ValueError(
1762                    "The exclude_threshold parameter has to be either a float value or a list of "
1763                    f"float values of the same length as exclude_classes (got a list of length "
1764                    f"{len(exclude_threshold)} for {len(exclude_classes)} classes)"
1765                )
1766        else:
1767            exclude_threshold = [exclude_threshold for _ in exclude_classes]
1768        if isinstance(exclude_threshold_diff, list):
1769            if len(exclude_threshold_diff) != len(exclude_classes):
1770                raise ValueError(
1771                    "The exclude_threshold_diff parameter has to be either a float value or a list of "
1772                    f"float values of the same length as exclude_classes (got a list of length "
1773                    f"{len(exclude_threshold_diff)} for {len(exclude_classes)} classes)"
1774                )
1775        else:
1776            exclude_threshold_diff = [exclude_threshold_diff for _ in exclude_classes]
1777        if isinstance(exclude_hysteresis, list):
1778            if len(exclude_hysteresis) != len(exclude_classes):
1779                raise ValueError(
1780                    "The suggestion_threshold_diff parameter has to be either a float value or a list of "
1781                    f"float values of the same length as suggestion_classes (got a list of length "
1782                    f"{len(exclude_hysteresis)} for {len(exclude_classes)} classes)"
1783                )
1784        else:
1785            exclude_hysteresis = [exclude_hysteresis for _ in exclude_classes]
1786        if isinstance(include_threshold, list):
1787            if len(include_threshold) != len(include_classes):
1788                raise ValueError(
1789                    "The exclude_threshold parameter has to be either a float value or a list of "
1790                    f"float values of the same length as exclude_classes (got a list of length "
1791                    f"{len(include_threshold)} for {len(include_classes)} classes)"
1792                )
1793        else:
1794            include_threshold = [include_threshold for _ in include_classes]
1795        if isinstance(include_threshold_diff, list):
1796            if len(include_threshold_diff) != len(include_classes):
1797                raise ValueError(
1798                    "The exclude_threshold_diff parameter has to be either a float value or a list of "
1799                    f"float values of the same length as exclude_classes (got a list of length "
1800                    f"{len(include_threshold_diff)} for {len(include_classes)} classes)"
1801                )
1802        else:
1803            include_threshold_diff = [include_threshold_diff for _ in include_classes]
1804        if isinstance(include_hysteresis, list):
1805            if len(include_hysteresis) != len(include_classes):
1806                raise ValueError(
1807                    "The suggestion_threshold_diff parameter has to be either a float value or a list of "
1808                    f"float values of the same length as suggestion_classes (got a list of length "
1809                    f"{len(include_hysteresis)} for {len(include_classes)} classes)"
1810                )
1811        else:
1812            include_hysteresis = [include_hysteresis for _ in include_classes]
1813        if (suggestion_episodes is None and suggestions_prediction is None) and len(
1814            exclude_classes
1815        ) > 0:
1816            raise ValueError(
1817                "In order to exclude classes from the active learning intervals you need to set the "
1818                "suggestion_episode parameter"
1819            )
1820
1821        task = None
1822        if error_episode is not None:
1823            task, parameters, mode = self._make_task_prediction(
1824                prediction_name=suggestions_name,
1825                load_episode=error_episode,
1826                parameters_update=parameters_update,
1827                load_epoch=error_load_epoch,
1828                data_path=data_path,
1829                mode=mode,
1830                file_paths=file_paths,
1831                task=task,
1832            )
1833            predicted_error = task.predict(
1834                data=mode,
1835                raw_output=True,
1836                apply_primary_function=True,
1837                augment_n=augment_n,
1838            )
1839        else:
1840            predicted_error = None
1841
1842        if suggestion_episodes is not None:
1843            (
1844                task,
1845                parameters,
1846                mode,
1847                predicted_classes,
1848                inference_time,
1849                behavior_dict,
1850            ) = self._make_prediction(
1851                prediction_name=suggestions_name,
1852                episode_names=suggestion_episodes,
1853                load_epochs=suggestion_load_epoch,
1854                parameters_update=parameters_update,
1855                data_path=data_path,
1856                file_paths=file_paths,
1857                mode=mode,
1858                task=task,
1859            )
1860        elif suggestions_prediction is not None:
1861            with open(
1862                os.path.join(
1863                    self.project_path,
1864                    "results",
1865                    "predictions",
1866                    f"{suggestions_prediction}.pickle",
1867                ),
1868                "rb",
1869            ) as f:
1870                predicted_classes = pickle.load(f)
1871            if parameters_update is None:
1872                parameters_update = {}
1873            parameters_update = self._update(
1874                self._predictions().load_parameters(suggestions_prediction),
1875                parameters_update,
1876            )
1877            parameters_update.pop("model")
1878            if suggestion_episodes is None:
1879                suggestion_episodes = [
1880                    os.path.basename(
1881                        os.path.dirname(
1882                            parameters_update["training"]["checkpoint_path"]
1883                        )
1884                    )
1885                ]
1886            task, parameters, mode = self._make_task_prediction(
1887                "_",
1888                load_episode=None,
1889                parameters_update=parameters_update,
1890                data_path=data_path,
1891                file_paths=file_paths,
1892                mode=mode,
1893            )
1894        else:
1895            predicted_classes = None
1896
1897        if len(suggestion_classes) > 0 and predicted_classes is not None:
1898            suggestions = self._make_suggestions(
1899                task,
1900                predicted_error,
1901                predicted_classes,
1902                suggestion_threshold,
1903                suggestion_threshold_diff,
1904                suggestion_hysteresis,
1905                suggestion_episodes,
1906                suggestion_classes,
1907                error_threshold,
1908                min_frames_suggestion,
1909                min_frames_al,
1910                visibility_min_score,
1911                visibility_min_frac,
1912                cut_annotated=cut_annotated,
1913            )
1914            videos = list(suggestions.keys())
1915            for v_id in videos:
1916                times_dict = defaultdict(lambda: defaultdict(lambda: []))
1917                clips = set()
1918                for c in suggestions[v_id]:
1919                    for start, end, ind in suggestions[v_id][c]:
1920                        times_dict[ind][c].append([start, end, 2])
1921                        clips.add(ind)
1922                clips = list(clips)
1923                times_dict = dict(times_dict)
1924                times = [
1925                    [times_dict[ind][c] for c in suggestion_classes] for ind in clips
1926                ]
1927                save_path = self._suggestion_path(v_id, suggestions_name)
1928                Path(os.path.dirname(save_path)).mkdir(parents=True, exist_ok=True)
1929                with open(save_path, "wb") as f:
1930                    pickle.dump((None, suggestion_classes, clips, times), f)
1931
1932        if (
1933            error_episode is not None
1934            or len(exclude_classes) > 0
1935            or len(include_classes) > 0
1936        ):
1937            al_points = self._make_al_points(
1938                task,
1939                predicted_error,
1940                predicted_classes,
1941                exclude_classes,
1942                exclude_threshold,
1943                exclude_threshold_diff,
1944                exclude_hysteresis,
1945                include_classes,
1946                include_threshold,
1947                include_threshold_diff,
1948                include_hysteresis,
1949                error_episode,
1950                error_class,
1951                suggestion_episodes,
1952                error_threshold,
1953                error_threshold_diff,
1954                error_hysteresis,
1955                min_frames_al,
1956                visibility_min_score,
1957                visibility_min_frac,
1958            )
1959        else:
1960            al_points = self._make_al_points_from_suggestions(
1961                suggestions_name,
1962                task,
1963                predicted_classes,
1964                background_threshold,
1965                visibility_min_score,
1966                visibility_min_frac,
1967                num_behaviors=len(task.behaviors_dict()),
1968            )
1969        save_path = self._al_points_path(suggestions_name)
1970        Path(os.path.dirname(save_path)).mkdir(parents=True, exist_ok=True)
1971        with open(save_path, "wb") as f:
1972            pickle.dump(al_points, f)
1973
1974        meta_parameters = {
1975            "error_episode": error_episode,
1976            "error_load_epoch": error_load_epoch,
1977            "error_class": error_class,
1978            "suggestion_episode": suggestion_episodes,
1979            "suggestion_load_epoch": suggestion_load_epoch,
1980            "suggestion_classes": suggestion_classes,
1981            "error_threshold": error_threshold,
1982            "error_threshold_diff": error_threshold_diff,
1983            "error_hysteresis": error_hysteresis,
1984            "suggestion_threshold": suggestion_threshold,
1985            "suggestion_threshold_diff": suggestion_threshold_diff,
1986            "suggestion_hysteresis": suggestion_hysteresis,
1987            "min_frames_suggestion": min_frames_suggestion,
1988            "min_frames_al": min_frames_al,
1989            "visibility_min_score": visibility_min_score,
1990            "visibility_min_frac": visibility_min_frac,
1991            "augment_n": augment_n,
1992            "exclude_classes": exclude_classes,
1993            "exclude_threshold": exclude_threshold,
1994            "exclude_threshold_diff": exclude_threshold_diff,
1995            "exclude_hysteresis": exclude_hysteresis,
1996        }
1997        self._save_suggestions(suggestions_name, {}, meta_parameters)
1998        if data_path is not None or file_paths is not None or remove_saved_features:
1999            self._remove_stores(parameters)
2000        print(f"\n")
2001
2002    def _generate_similarity_score(
2003        self,
2004        prediction_name: str,
2005        target_video_id: str,
2006        target_clip: str,
2007        target_start: int,
2008        target_end: int,
2009    ) -> Dict:
2010        with open(
2011            os.path.join(
2012                self.project_path,
2013                "results",
2014                "predictions",
2015                f"{prediction_name}.pickle",
2016            ),
2017            "rb",
2018        ) as f:
2019            prediction = pickle.load(f)
2020        target = prediction[target_video_id][target_clip][:, target_start:target_end]
2021        score_dict = defaultdict(lambda: {})
2022        for video_id in prediction:
2023            for clip_id in prediction[video_id]:
2024                score_dict[video_id][clip_id] = torch.cdist(
2025                    target.T, prediction[video_id][score_dict].T
2026                ).min(0)
2027        return score_dict
2028
2029    def _suggest_intervals_from_dict(self, score_dict, min_length, n_intervals) -> Dict:
2030        """Suggest intervals from a score dictionary.
2031
2032        Parameters
2033        ----------
2034        score_dict : dict
2035            a dictionary containing scores for intervals
2036        min_length : int
2037            minimum length of intervals to suggest
2038        n_intervals : int
2039            number of intervals to suggest
2040
2041        Returns
2042        -------
2043        intervals : dict
2044            a dictionary of suggested intervals
2045
2046        """
2047        interval_address = {}
2048        interval_value = {}
2049        s = 0
2050        n = 0
2051        for video_id, video_dict in score_dict.items():
2052            for clip_id, value in video_dict.items():
2053                s += value.mean()
2054                n += 1
2055        mean_value = s / n
2056        alpha = 1.75
2057        for it in range(10):
2058            id = 0
2059            interval_address = {}
2060            interval_value = {}
2061            for video_id, video_dict in score_dict.items():
2062                for clip_id, value in video_dict.items():
2063                    res_indices_start, res_indices_end = apply_threshold(
2064                        value,
2065                        threshold=(2 - alpha * (0.9**it)) * mean_value,
2066                        low=True,
2067                        error_mask=None,
2068                        min_frames=min_length,
2069                        smooth_interval=0,
2070                    )
2071                    for start, end in zip(res_indices_start, res_indices_end):
2072                        interval_address[id] = [video_id, clip_id, start, end]
2073                        interval_value[id] = score_dict[video_id][clip_id][
2074                            start:end
2075                        ].mean()
2076                        id += 1
2077            if len(interval_address) >= n_intervals:
2078                break
2079        if len(interval_address) < n_intervals:
2080            warnings.warn(
2081                f"Could not get {n_intervals} intervals from the data, saving the result with {len(interval_address)} intervals"
2082            )
2083        sorted_intervals = sorted(
2084            interval_value.items(), key=lambda x: x[1], reverse=True
2085        )
2086        output_intervals = [
2087            interval_address[x[0]]
2088            for x in sorted_intervals[: min(len(sorted_intervals), n_intervals)]
2089        ]
2090        output = defaultdict(lambda: [])
2091        for video_id, clip_id, start, end in output_intervals:
2092            output[video_id].append([start, end, clip_id])
2093        return output
2094
2095    def suggest_intervals_with_similarity(
2096        self,
2097        suggestions_name: str,
2098        prediction_name: str,
2099        target_video_id: str,
2100        target_clip: str,
2101        target_start: int,
2102        target_end: int,
2103        min_length: int = 60,
2104        n_intervals: int = 5,
2105        force: bool = False,
2106    ):
2107        """
2108        Suggest intervals based on similarity to a target interval.
2109
2110        Parameters
2111        ----------
2112        suggestions_name : str
2113            Name of the suggestion.
2114        prediction_name : str
2115            Name of the prediction to use.
2116        target_video_id : str
2117            Video id of the target interval.
2118        target_clip : str
2119            Clip id of the target interval.
2120        target_start : int
2121            Start frame of the target interval.
2122        target_end : int
2123            End frame of the target interval.
2124        min_length : int, default 60
2125            Minimum length of the suggested intervals.
2126        n_intervals : int, default 5
2127            Number of suggested intervals.
2128        force : bool, default False
2129            If True, the suggestion is overwritten if it already exists.
2130
2131        """
2132        self._check_suggestions_validity(suggestions_name, force=force)
2133        print(f"SUGGESTION {suggestions_name}")
2134        score_dict = self._generate_similarity_score(
2135            prediction_name, target_video_id, target_clip, target_start, target_end
2136        )
2137        intervals = self._suggest_intervals_from_dict(
2138            score_dict, min_length, n_intervals
2139        )
2140        suggestions_path = os.path.join(
2141            self.project_path,
2142            "results",
2143            "suggestions",
2144            suggestions_name,
2145        )
2146        if not os.path.exists(suggestions_path):
2147            os.mkdir(suggestions_path)
2148        with open(
2149            os.path.join(suggestions_path, f"{suggestions_name}_al_points.pickle"), "wb"
2150        ) as f:
2151            pickle.dump(intervals, f)
2152        meta_parameters = {
2153            "prediction_name": prediction_name,
2154            "min_frames_suggestion": min_length,
2155            "n_intervals": n_intervals,
2156            "target_clip": target_clip,
2157            "target_start": target_start,
2158            "target_end": target_end,
2159        }
2160        self._save_suggestions(suggestions_name, {}, meta_parameters)
2161        print("\n")
2162
2163    def suggest_intervals_with_uncertainty(
2164        self,
2165        suggestions_name: str,
2166        episode_names: List,
2167        load_epochs: List = None,
2168        classes: List = None,
2169        n_frames: int = 10000,
2170        method: str = "least_confidence",
2171        min_length: int = 60,
2172        augment_n: int = 0,
2173        data_path: str = None,
2174        file_paths: Set = None,
2175        parameters_update: Dict = None,
2176        mode: str = "all",
2177        force: bool = False,
2178        remove_saved_features: bool = False,
2179    ) -> None:
2180        """Generate an active learning file based on model uncertainty.
2181
2182        If you provide several episode names, the predicted probabilities will be averaged.
2183
2184        Parameters
2185        ----------
2186        suggestions_name : str
2187            the name of the suggestion
2188        episode_names : list
2189            a list of string episode names to load the models from
2190        load_epochs : list, optional
2191            a list of epoch indices to load the models from (if `None`, the last ones will be used)
2192        classes : list, optional
2193            a list of classes to look at (by default all)
2194        n_frames : int, default 10000
2195            the threshold total number of frames in the suggested intervals (in the end result it will most likely
2196            be slightly larger; it will only be smaller if the algorithm fails to find enough intervals
2197            with the set parameters)
2198        method : {"least_confidence", "entropy"}
2199            the method used to calculate the scores from the probability predictions (`"least_confidence"`: `1 - p_i` if
2200            `p_i > 0.5` or `p_i` if `p_i < 0.5`; `"entropy"`: `- p_i * log(p_i) - (1 - p_i) * log(1 - p_i)`)
2201        min_length : int, default 60
2202            the minimum number of frames in one interval
2203        augment_n : int, default 0
2204            the number of augmentations to average the predictions over
2205        data_path : str, optional
2206            the path to a data folder (by default, the project data is used)
2207        file_paths : set, optional
2208            a list of file paths (by default, the project data is used)
2209        parameters_update : dict, optional
2210            a dictionary of parameter updates
2211        mode : {"test", "val", "train", "all"}
2212            the subset of the data to make the prediction for (forced to 'all' if `data_path` is not `None`;
2213            by default set to `'test'` if the test subset if not empty, or to `'val'` otherwise)
2214        force : bool, default False
2215            if `True`, existing suggestions with the same name will be overwritten
2216        remove_saved_features : bool, default False
2217            if `True`, the dataset will be deleted after the computation
2218
2219        """
2220        self._check_suggestions_validity(suggestions_name, force=force)
2221        print(f"SUGGESTION {suggestions_name}")
2222        task, parameters, mode, predicted, inference_time, behavior_dict = (
2223            self._make_prediction(
2224                suggestions_name,
2225                episode_names,
2226                load_epochs,
2227                parameters_update,
2228                data_path=data_path,
2229                file_paths=file_paths,
2230                mode=mode,
2231                augment_n=augment_n,
2232                evaluate=False,
2233            )
2234        )
2235        if classes is None:
2236            classes = behavior_dict.values()
2237        episode = self._episodes().get_runs(episode_names[0])[0]
2238        score_tensors = task.generate_uncertainty_score(
2239            classes,
2240            augment_n,
2241            method,
2242            predicted,
2243            self._episode(episode).get_behaviors_dict(),
2244        )
2245        intervals = self._suggest_intervals(
2246            task.dataset(mode), score_tensors, n_frames, min_length
2247        )
2248        for k, v in intervals.items():
2249            l = sum([x[1] - x[0] for x in v])
2250            print(f"{k}: {len(v)} ({l})")
2251        if remove_saved_features:
2252            self._remove_stores(parameters)
2253        suggestions_path = os.path.join(
2254            self.project_path,
2255            "results",
2256            "suggestions",
2257            suggestions_name,
2258        )
2259        if not os.path.exists(suggestions_path):
2260            os.mkdir(suggestions_path)
2261        with open(
2262            os.path.join(suggestions_path, f"{suggestions_name}_al_points.pickle"), "wb"
2263        ) as f:
2264            pickle.dump(intervals, f)
2265        meta_parameters = {
2266            "suggestion_episode": episode_names,
2267            "suggestion_load_epoch": load_epochs,
2268            "suggestion_classes": classes,
2269            "min_frames_suggestion": min_length,
2270            "augment_n": augment_n,
2271            "method": method,
2272            "num_frames": n_frames,
2273        }
2274        self._save_suggestions(suggestions_name, {}, meta_parameters)
2275        print("\n")
2276
2277    def suggest_intervals_with_bald(
2278        self,
2279        suggestions_name: str,
2280        episode_name: str,
2281        load_epoch: int = None,
2282        classes: List = None,
2283        n_frames: int = 10000,
2284        num_models: int = 10,
2285        kernel_size: int = 11,
2286        min_length: int = 60,
2287        augment_n: int = 0,
2288        data_path: str = None,
2289        file_paths: Set = None,
2290        parameters_update: Dict = None,
2291        mode: str = "all",
2292        force: bool = False,
2293        remove_saved_features: bool = False,
2294    ):
2295        """Generate an active learning file based on Bayesian Active Learning by Disagreement.
2296
2297        Parameters
2298        ----------
2299        suggestions_name : str
2300            the name of the suggestion
2301        episode_name : str
2302            the name of the episode to load the model from
2303        load_epoch : int, optional
2304            the index of the epoch to load the model from (if `None`, the last one will be used)
2305        classes : list, optional
2306            a list of classes to look at (by default all)
2307        n_frames : int, default 10000
2308            the threshold total number of frames in the suggested intervals (in the end result it will most likely
2309            be slightly larger; it will only be smaller if the algorithm fails to find enough intervals
2310            with the set parameters)
2311        num_models : int, default 10
2312            the number of dropout masks to apply
2313        kernel_size : int, default 11
2314            the size of the smoothing kernel applied to the discrete results
2315        min_length : int, default 60
2316            the minimum number of frames in one interval
2317        augment_n : int, default 0
2318            the number of augmentations to average the predictions over
2319        data_path : str, optional
2320            the path to a data folder (by default, the project data is used)
2321        file_paths : set, optional
2322            a list of file paths (by default, the project data is used)
2323        parameters_update : dict, optional
2324            a dictionary of parameter updates
2325        mode : {"test", "val", "train", "all"}
2326            the subset of the data to make the prediction for (forced to 'all' if `data_path` is not `None`;
2327            by default set to `'test'` if the test subset if not empty, or to `'val'` otherwise)
2328        force : bool, default False
2329            if `True`, existing suggestions with the same name will be overwritten
2330        remove_saved_features : bool, default False
2331            if `True`, the dataset will be deleted after the computation
2332
2333        """
2334        self._check_suggestions_validity(suggestions_name, force=force)
2335        print(f"SUGGESTION {suggestions_name}")
2336        task, parameters, mode = self._make_task_prediction(
2337            suggestions_name,
2338            episode_name,
2339            parameters_update,
2340            load_epoch,
2341            data_path=data_path,
2342            file_paths=file_paths,
2343            mode=mode,
2344        )
2345        if classes is None:
2346            classes = list(task.behaviors_dict().values())
2347        score_tensors = task.generate_bald_score(
2348            classes, augment_n, num_models, kernel_size
2349        )
2350        intervals = self._suggest_intervals(
2351            task.dataset(mode), score_tensors, n_frames, min_length
2352        )
2353        if remove_saved_features:
2354            self._remove_stores(parameters)
2355        suggestions_path = os.path.join(
2356            self.project_path,
2357            "results",
2358            "suggestions",
2359            suggestions_name,
2360        )
2361        if not os.path.exists(suggestions_path):
2362            os.mkdir(suggestions_path)
2363        with open(
2364            os.path.join(suggestions_path, f"{suggestions_name}_al_points.pickle"), "wb"
2365        ) as f:
2366            pickle.dump(intervals, f)
2367        meta_parameters = {
2368            "suggestion_episode": episode_name,
2369            "suggestion_load_epoch": load_epoch,
2370            "suggestion_classes": classes,
2371            "min_frames_suggestion": min_length,
2372            "augment_n": augment_n,
2373            "method": f"BALD:{num_models}",
2374            "num_frames": n_frames,
2375        }
2376        self._save_suggestions(suggestions_name, {}, meta_parameters)
2377        print("\n")
2378
2379    def list_episodes(
2380        self,
2381        episode_names: List = None,
2382        value_filter: str = "",
2383        display_parameters: List = None,
2384        print_results: bool = True,
2385    ) -> pd.DataFrame:
2386        """Get a filtered pandas dataframe with episode metadata.
2387
2388        Parameters
2389        ----------
2390        episode_names : list
2391            a list of strings of episode names
2392        value_filter : str
2393            a string of filters to apply; of this general structure:
2394            'group_name1/par_name1::(</>/<=/>=/=)value1,group_name2/par_name2::(</>/<=/>=/=)value2', e.g.
2395            'data/overlap::=50,results/recall::>0.5,data/feature_extraction::=kinematic,meta/training_time::>=00:00:10'
2396        display_parameters : list
2397            list of parameters to display (e.g. ['data/overlap', 'results/recall'])
2398        print_results : bool, default True
2399            if True, the result will be printed to standard output
2400
2401        Returns
2402        -------
2403        pd.DataFrame
2404            the filtered dataframe
2405
2406        """
2407        episodes = self._episodes().list_episodes(
2408            episode_names, value_filter, display_parameters
2409        )
2410        if print_results:
2411            print("TRAINING EPISODES")
2412            print(episodes)
2413            print("\n")
2414        return episodes
2415
2416    def list_predictions(
2417        self,
2418        episode_names: List = None,
2419        value_filter: str = "",
2420        display_parameters: List = None,
2421        print_results: bool = True,
2422    ) -> pd.DataFrame:
2423        """Get a filtered pandas dataframe with prediction metadata.
2424
2425        Parameters
2426        ----------
2427        episode_names : list
2428            a list of strings of episode names
2429        value_filter : str
2430            a string of filters to apply; of this general structure:
2431            'group_name1/par_name1:(<>=)value1,group_name2/par_name2:(<>=)value2', e.g.
2432            'data/overlap:=50,results/recall:>0.5,data/feature_extraction:=kinematic'
2433        display_parameters : list
2434            list of parameters to display (e.g. ['data/overlap', 'results/recall'])
2435        print_results : bool, default True
2436            if True, the result will be printed to standard output
2437
2438        Returns
2439        -------
2440        pd.DataFrame
2441            the filtered dataframe
2442
2443        """
2444        predictions = self._predictions().list_episodes(
2445            episode_names, value_filter, display_parameters
2446        )
2447        if print_results:
2448            print("PREDICTIONS")
2449            print(predictions)
2450            print("\n")
2451        return predictions
2452
2453    def list_suggestions(
2454        self,
2455        suggestions_names: List = None,
2456        value_filter: str = "",
2457        display_parameters: List = None,
2458        print_results: bool = True,
2459    ) -> pd.DataFrame:
2460        """Get a filtered pandas dataframe with prediction metadata.
2461
2462        Parameters
2463        ----------
2464        suggestions_names : list
2465            a list of strings of suggestion names
2466        value_filter : str
2467            a string of filters to apply; of this general structure:
2468            'group_name1/par_name1:(<>=)value1,group_name2/par_name2:(<>=)value2', e.g.
2469            'data/overlap:=50,results/recall:>0.5,data/feature_extraction:=kinematic'
2470        display_parameters : list
2471            list of parameters to display (e.g. ['data/overlap', 'results/recall'])
2472        print_results : bool, default True
2473            if True, the result will be printed to standard output
2474
2475        Returns
2476        -------
2477        pd.DataFrame
2478            the filtered dataframe
2479
2480        """
2481        suggestions = self._suggestions().list_episodes(
2482            suggestions_names, value_filter, display_parameters
2483        )
2484        if print_results:
2485            print("SUGGESTIONS")
2486            print(suggestions)
2487            print("\n")
2488        return suggestions
2489
2490    def list_searches(
2491        self,
2492        search_names: List = None,
2493        value_filter: str = "",
2494        display_parameters: List = None,
2495        print_results: bool = True,
2496    ) -> pd.DataFrame:
2497        """Get a filtered pandas dataframe with hyperparameter search metadata.
2498
2499        Parameters
2500        ----------
2501        search_names : list
2502            a list of strings of search names
2503        value_filter : str
2504            a string of filters to apply; of this general structure:
2505            'group_name1/par_name1:(<>=)value1,group_name2/par_name2:(<>=)value2', e.g.
2506            'data/overlap:=50,results/recall:>0.5,data/feature_extraction:=kinematic'
2507        display_parameters : list
2508            list of parameters to display (e.g. ['data/overlap', 'results/recall'])
2509        print_results : bool, default True
2510            if True, the result will be printed to standard output
2511
2512        Returns
2513        -------
2514        pd.DataFrame
2515            the filtered dataframe
2516
2517        """
2518        searches = self._searches().list_episodes(
2519            search_names, value_filter, display_parameters
2520        )
2521        if print_results:
2522            print("SEARCHES")
2523            print(searches)
2524            print("\n")
2525        return searches
2526
2527    def get_best_parameters(
2528        self,
2529        search_name: str,
2530        round_to_binary: List = None,
2531    ):
2532        """Get the best parameters found by a search.
2533
2534        Parameters
2535        ----------
2536        search_name : str
2537            the name of the search
2538        round_to_binary : list, default None
2539            a list of parameters to round to binary values
2540
2541        Returns
2542        -------
2543        best_params : dict
2544            a dictionary of the best parameters where the keys are in '{group}/{name}' format
2545
2546        """
2547        params, model = self._searches().get_best_params(
2548            search_name, round_to_binary=round_to_binary
2549        )
2550        params = self._update(params, {"general": {"model_name": model}})
2551        return params
2552
2553    def list_best_parameters(
2554        self, search_name: str, print_results: bool = True
2555    ) -> Dict:
2556        """Get the raw dictionary of best parameters found by a search.
2557
2558        Parameters
2559        ----------
2560        search_name : str
2561            the name of the search
2562        print_results : bool, default True
2563            if True, the result will be printed to standard output
2564
2565        Returns
2566        -------
2567        best_params : dict
2568            a dictionary of the best parameters where the keys are in '{group}/{name}' format
2569
2570        """
2571        params = self._searches().get_best_params_raw(search_name)
2572        if print_results:
2573            print(f"SEARCH RESULTS {search_name}")
2574            for k, v in params.items():
2575                print(f"{k}: {v}")
2576            print("\n")
2577        return params
2578
2579    def plot_episodes(
2580        self,
2581        episode_names: List,
2582        metrics: List | str,
2583        modes: List | str = None,
2584        title: str = None,
2585        episode_labels: List = None,
2586        save_path: str = None,
2587        add_hlines: List = None,
2588        epoch_limits: List = None,
2589        colors: List = None,
2590        add_highpoint_hlines: bool = False,
2591        remove_box: bool = False,
2592        font_size: float = None,
2593        linewidth: float = None,
2594        return_ax: bool = False,
2595    ) -> None:
2596        """Plot episode training curves.
2597
2598        Parameters
2599        ----------
2600        episode_names : list
2601            a list of episode names to plot; to plot to episodes in one line combine them in a list
2602            (e.g. ['episode1', ['episode2', 'episode3']] to plot episode2 and episode3 as one experiment)
2603        metrics : list
2604            a list of metric to plot
2605        modes : list, optional
2606            a list of modes to plot ('train' and/or 'val'; `['val']` by default)
2607        title : str, optional
2608            title for the plot
2609        episode_labels : list, optional
2610            a list of strings used to label the curves (has to be the same length as episode_names)
2611        save_path : str, optional
2612            the path to save the resulting plot
2613        add_hlines : list, optional
2614            a list of float values (or (value, label) tuples) to mark with horizontal lines
2615        epoch_limits : list, optional
2616            a list of (min, max) tuples to set the x-axis limits for each episode
2617        colors: list, optional
2618            a list of matplotlib colors
2619        add_highpoint_hlines : bool, default False
2620            if `True`, horizontal lines will be added at the highest value of each episode
2621        """
2622
2623        if isinstance(metrics, str):
2624            metrics = [metrics]
2625        if isinstance(modes, str):
2626            modes = [modes]
2627
2628        if font_size is not None:
2629            font = {"size": font_size}
2630            rc("font", **font)
2631        if modes is None:
2632            modes = ["val"]
2633        if add_hlines is None:
2634            add_hlines = []
2635        logs = []
2636        epochs = []
2637        labels = []
2638        if episode_labels is not None:
2639            assert len(episode_labels) == len(episode_names)
2640        for name_i, name in enumerate(episode_names):
2641            log_params = product(metrics, modes)
2642            for metric, mode in log_params:
2643                if episode_labels is not None:
2644                    label = episode_labels[name_i]
2645                else:
2646                    label = deepcopy(name)
2647                if len(modes) != 1:
2648                    label += f"_{mode}"
2649                if len(metrics) != 1:
2650                    label += f"_{metric}"
2651                labels.append(label)
2652                if isinstance(name, Iterable) and not isinstance(name, str):
2653                    epoch_list = defaultdict(lambda: [])
2654                    multi_logs = defaultdict(lambda: [])
2655                    for i, n in enumerate(name):
2656                        runs = self._episodes().get_runs(n)
2657                        if len(runs) > 1:
2658                            for run in runs:
2659                                if "::" in run:
2660                                    index = run.split("::")[-1]
2661                                else:
2662                                    index = run.split("#")[-1]
2663                                if multi_logs[index] == []:
2664                                    if multi_logs["null"] is None:
2665                                        raise RuntimeError(
2666                                            "The run indices are not consistent across episodes!"
2667                                        )
2668                                    else:
2669                                        multi_logs[index] += multi_logs["null"]
2670                                multi_logs[index] += list(
2671                                    self._episode(run).get_metric_log(mode, metric)
2672                                )
2673                                start = (
2674                                    0
2675                                    if len(epoch_list[index]) == 0
2676                                    else epoch_list[index][-1]
2677                                )
2678                                epoch_list[index] += [
2679                                    x + start
2680                                    for x in self._episode(run).get_epoch_list(mode)
2681                                ]
2682                            multi_logs["null"] = None
2683                        else:
2684                            if len(multi_logs.keys()) > 1:
2685                                raise RuntimeError(
2686                                    "Cannot plot a single-run episode after a multi-run episode!"
2687                                )
2688                            multi_logs["null"] += list(
2689                                self._episode(n).get_metric_log(mode, metric)
2690                            )
2691                            start = (
2692                                0
2693                                if len(epoch_list["null"]) == 0
2694                                else epoch_list["null"][-1]
2695                            )
2696                            epoch_list["null"] += [
2697                                x + start for x in self._episode(n).get_epoch_list(mode)
2698                            ]
2699                    if len(multi_logs.keys()) == 1:
2700                        log = multi_logs["null"]
2701                        epochs.append(epoch_list["null"])
2702                    else:
2703                        log = tuple([v for k, v in multi_logs.items() if k != "null"])
2704                        epochs.append(
2705                            tuple([v for k, v in epoch_list.items() if k != "null"])
2706                        )
2707                else:
2708                    runs = self._episodes().get_runs(name)
2709                    if len(runs) > 1:
2710                        log = []
2711                        for run in runs:
2712                            tracked_metrics = self._episode(run).get_metrics()
2713                            if metric in tracked_metrics:
2714                                log.append(
2715                                    list(
2716                                        self._episode(run).get_metric_log(mode, metric)
2717                                    )
2718                                )
2719                            else:
2720                                relevant = []
2721                                for m in tracked_metrics:
2722                                    m_split = m.split("_")
2723                                    if (
2724                                        "_".join(m_split[:-1]) == metric
2725                                        and m_split[-1].isnumeric()
2726                                    ):
2727                                        relevant.append(m)
2728                                if len(relevant) == 0:
2729                                    raise ValueError(
2730                                        f"The {metric} metric was not tracked at {run}"
2731                                    )
2732                                arr = 0
2733                                for m in relevant:
2734                                    arr += self._episode(run).get_metric_log(mode, m)
2735                                arr /= len(relevant)
2736                                log.append(list(arr))
2737                        log = tuple(log)
2738                        epochs.append(
2739                            tuple(
2740                                [
2741                                    self._episode(run).get_epoch_list(mode)
2742                                    for run in runs
2743                                ]
2744                            )
2745                        )
2746                    else:
2747                        tracked_metrics = self._episode(name).get_metrics()
2748                        if metric in tracked_metrics:
2749                            log = list(self._episode(name).get_metric_log(mode, metric))
2750                        else:
2751                            relevant = []
2752                            for m in tracked_metrics:
2753                                m_split = m.split("_")
2754                                if (
2755                                    "_".join(m_split[:-1]) == metric
2756                                    and m_split[-1].isnumeric()
2757                                ):
2758                                    relevant.append(m)
2759                            if len(relevant) == 0:
2760                                raise ValueError(
2761                                    f"The {metric} metric was not tracked at {name}"
2762                                )
2763                            arr = 0
2764                            for m in relevant:
2765                                arr += self._episode(name).get_metric_log(mode, m)
2766                            arr /= len(relevant)
2767                            log = list(arr)
2768                        epochs.append(self._episode(name).get_epoch_list(mode))
2769                logs.append(log)
2770        # if episode_labels is not None:
2771        #     print(f'{len(episode_labels)=}, {len(logs)=}')
2772        #     if len(episode_labels) != len(logs):
2773
2774        #         raise ValueError(
2775        #             f"The length of episode_labels ({len(episode_labels)}) has to be equal to the length of "
2776        #             f"curves ({len(logs)})!"
2777        #         )
2778        #     else:
2779        #         labels = episode_labels
2780        if colors is None:
2781            colors = cm.rainbow(np.linspace(0, 1, len(logs)))
2782        if len(colors) != len(logs):
2783            raise ValueError(
2784                "The length of colors has to be equal to the length of curves (metrics * modes * episode_names)!"
2785            )
2786        f, ax = plt.subplots()
2787        length = 0
2788        for log, label, color, epoch_list in zip(logs, labels, colors, epochs):
2789            if type(log) is list:
2790                if len(log) > length:
2791                    length = len(log)
2792                ax.plot(
2793                    epoch_list,
2794                    log,
2795                    label=label,
2796                    color=color,
2797                )
2798                if add_highpoint_hlines:
2799                    ax.axhline(np.max(log), linestyle="dashed", color=color)
2800            else:
2801                for l, xx in zip(log, epoch_list):
2802                    if len(l) > length:
2803                        length = len(l)
2804                    ax.plot(
2805                        xx,
2806                        l,
2807                        color=color,
2808                        alpha=0.2,
2809                    )
2810                if not all([len(x) == len(log[0]) for x in log]):
2811                    warnings.warn(
2812                        f"Got logs with unequal lengths in parallel runs for {label}"
2813                    )
2814                    log = list(log)
2815                    epoch_list = list(epoch_list)
2816                    for i, x in enumerate(epoch_list):
2817                        to_remove = []
2818                        for j, y in enumerate(x[1:]):
2819                            if y <= x[j - 1]:
2820                                y_ind = x.index(y)
2821                                to_remove += list(range(y_ind, j))
2822                        epoch_list[i] = [
2823                            y for j, y in enumerate(x) if j not in to_remove
2824                        ]
2825                        log[i] = [y for j, y in enumerate(log[i]) if j not in to_remove]
2826                    length = min([len(x) for x in log])
2827                    for i in range(len(log)):
2828                        log[i] = log[i][:length]
2829                        epoch_list[i] = epoch_list[i][:length]
2830                    if not all([x == epoch_list[0] for x in epoch_list]):
2831                        raise RuntimeError(
2832                            f"Got different epoch indices in parallel runs for {label}"
2833                        )
2834                mean = np.array(log).mean(0)
2835                ax.plot(
2836                    epoch_list[0],
2837                    mean,
2838                    label=label,
2839                    color=color,
2840                    linewidth=linewidth,
2841                )
2842                if add_highpoint_hlines:
2843                    ax.axhline(np.max(mean), linestyle="dashed", color=color)
2844        for x in add_hlines:
2845            label = None
2846            if isinstance(x, Iterable):
2847                x, label = x
2848            ax.axhline(x, label=label)
2849            ax.set_xlim((0, length))
2850
2851        ax.legend()
2852        ax.set_xlabel("epochs")
2853        if len(metrics) == 1:
2854            ax.set_ylabel(metrics[0])
2855        else:
2856            ax.set_ylabel("value")
2857        if title is None:
2858            if len(episode_names) == 1:
2859                title = episode_names[0]
2860            elif len(metrics) == 1:
2861                title = metrics[0]
2862        if epoch_limits is not None:
2863            ax.set_xlim(epoch_limits)
2864        if title is not None:
2865            ax.set_title(title)
2866        if remove_box:
2867            ax.box(False)
2868        if return_ax:
2869            return ax
2870        if save_path is not None:
2871            plt.savefig(save_path)
2872        plt.show()
2873
2874    def update_parameters(
2875        self,
2876        parameters_update: Dict = None,
2877        load_search: str = None,
2878        load_parameters: List = None,
2879        round_to_binary: List = None,
2880    ) -> None:
2881        """Update the parameters in the project config files.
2882
2883        Parameters
2884        ----------
2885        parameters_update : dict, optional
2886            a dictionary of parameter updates
2887        load_search : str, optional
2888            the name of hyperparameter search results to load to config
2889        load_parameters : list, optional
2890            a list of lists of string names of the parameters to load from the searches
2891        round_to_binary : list, optional
2892            a list of string names of the loaded parameters that should be rounded to the nearest power of two
2893
2894        """
2895        keys = [
2896            "general",
2897            "losses",
2898            "metrics",
2899            "ssl",
2900            "training",
2901            "data",
2902        ]
2903        parameters = self._read_parameters(catch_blanks=False)
2904        if parameters_update is not None:
2905            model_params = (
2906                parameters_update.pop("model") if "model" in parameters_update else None
2907            )
2908            feat_params = (
2909                parameters_update.pop("features")
2910                if "features" in parameters_update
2911                else None
2912            )
2913            aug_params = (
2914                parameters_update.pop("augmentations")
2915                if "augmentations" in parameters_update
2916                else None
2917            )
2918
2919            parameters = self._update(parameters, parameters_update)
2920            model_name = parameters["general"]["model_name"]
2921            parameters["model"] = self._open_yaml(
2922                os.path.join(self.project_path, "config", "model", f"{model_name}.yaml")
2923            )
2924            if model_params is not None:
2925                parameters["model"] = self._update(parameters["model"], model_params)
2926            feat_name = parameters["general"]["feature_extraction"]
2927            parameters["features"] = self._open_yaml(
2928                os.path.join(
2929                    self.project_path, "config", "features", f"{feat_name}.yaml"
2930                )
2931            )
2932            if feat_params is not None:
2933                parameters["features"] = self._update(
2934                    parameters["features"], feat_params
2935                )
2936            aug_name = options.extractor_to_transformer[
2937                parameters["general"]["feature_extraction"]
2938            ]
2939            parameters["augmentations"] = self._open_yaml(
2940                os.path.join(
2941                    self.project_path, "config", "augmentations", f"{aug_name}.yaml"
2942                )
2943            )
2944            if aug_params is not None:
2945                parameters["augmentations"] = self._update(
2946                    parameters["augmentations"], aug_params
2947                )
2948        if load_search is not None:
2949            parameters_update, model_name = self._searches().get_best_params(
2950                load_search, load_parameters, round_to_binary
2951            )
2952            parameters["general"]["model_name"] = model_name
2953            parameters["model"] = self._open_yaml(
2954                os.path.join(self.project_path, "config", "model", f"{model_name}.yaml")
2955            )
2956            parameters = self._update(parameters, parameters_update)
2957        for key in keys:
2958            with open(
2959                os.path.join(self.project_path, "config", f"{key}.yaml"),
2960                "w",
2961                encoding="utf-8",
2962            ) as f:
2963                YAML().dump(parameters[key], f)
2964        model_name = parameters["general"]["model_name"]
2965        model_path = os.path.join(
2966            self.project_path, "config", "model", f"{model_name}.yaml"
2967        )
2968        with open(model_path, "w", encoding="utf-8") as f:
2969            YAML().dump(parameters["model"], f)
2970        features_name = parameters["general"]["feature_extraction"]
2971        features_path = os.path.join(
2972            self.project_path, "config", "features", f"{features_name}.yaml"
2973        )
2974        with open(features_path, "w", encoding="utf-8") as f:
2975            YAML().dump(parameters["features"], f)
2976        aug_name = options.extractor_to_transformer[features_name]
2977        aug_path = os.path.join(
2978            self.project_path, "config", "augmentations", f"{aug_name}.yaml"
2979        )
2980        with open(aug_path, "w", encoding="utf-8") as f:
2981            YAML().dump(parameters["augmentations"], f)
2982
2983    def get_summary(
2984        self,
2985        episode_names: list,
2986        method: str = "last",
2987        average: int = 1,
2988        metrics: List = None,
2989        return_values: bool = False,
2990    ) -> Dict:
2991        """Get a summary of episode statistics.
2992
2993        If an episode has multiple runs, the statistics will be aggregated over all of them.
2994
2995        Parameters
2996        ----------
2997        episode_names : str
2998            the names of the episodes
2999        method : ["best", "last"]
3000            the method for choosing the epochs
3001        average : int, default 1
3002            the number of epochs to average over (for each run)
3003        metrics : list, optional
3004            a list of metrics
3005
3006        Returns
3007        -------
3008        statistics : dict
3009            a nested dictionary where first-level keys are metric names and second-level keys are 'mean' for the mean
3010            and 'std' for the standard deviation
3011
3012        """
3013        runs = []
3014        for episode_name in episode_names:
3015            runs_ep = self._episodes().get_runs(episode_name)
3016            if len(runs_ep) == 0:
3017                raise RuntimeError(
3018                    f"There is no {episode_name} episode in the project memory"
3019                )
3020            runs += runs_ep
3021        if metrics is None:
3022            metrics = self._episode(runs[0]).get_metrics()
3023
3024        values = {m: [] for m in metrics}
3025        for run in runs:
3026            for m in metrics:
3027                log = self._episode(run).get_metric_log(mode="val", metric_name=m)
3028                if method == "best":
3029                    log = sorted(log)
3030                    values[m] += list(log[-average:])
3031                elif method == "last":
3032                    if len(log) == 0:
3033                        episodes = self._episodes().data
3034                        if average == 1 and ("results", m) in episodes.columns:
3035                            values[m] += [episodes.loc[run, ("results", m)]]
3036                        else:
3037                            raise RuntimeError(f"Did not find {m} metric for {run} run")
3038                    values[m] += list(log[-average:])
3039                elif method.startswith("epoch"):
3040                    epoch = int(method[5:]) - 1
3041                    pars = self._episodes().load_parameters(run)
3042                    step = int(pars["training"]["validation_interval"])
3043                    values[m] += [log[epoch // step]]
3044                else:
3045                    raise ValueError(
3046                        f"The {method} method is not recognized! Please choose from ['last', 'best', 'epoch...']"
3047                    )
3048        statistics = defaultdict(lambda: {})
3049        for m, v in values.items():
3050            statistics[m]["mean"] = np.mean(v)
3051            statistics[m]["std"] = np.std(v)
3052        print(f"SUMMARY {episode_names}")
3053        for m, v in statistics.items():
3054            print(f'{m}: mean {v["mean"]:.3f}, std {v["std"]:.3f}')
3055        print("\n")
3056
3057        return (dict(statistics), values) if return_values else dict(statistics)
3058
3059    @staticmethod
3060    def remove_project(name: str, projects_path: str = None) -> None:
3061        """Remove all project files and experiment records and results.
3062
3063        Parameters
3064        ----------
3065        name : str
3066            the name of the project to remove
3067        projects_path : str, optional
3068            the path to the projects directory (by default the home DLC2Action directory)
3069
3070        """
3071        if projects_path is None:
3072            projects_path = os.path.join(str(Path.home()), "DLC2Action")
3073        project_path = os.path.join(projects_path, name)
3074        if os.path.exists(project_path):
3075            shutil.rmtree(project_path)
3076
3077    def remove_saved_features(
3078        self,
3079        dataset_names: List = None,
3080        exceptions: List = None,
3081        remove_active: bool = False,
3082    ) -> None:
3083        """Remove saved pre-computed dataset feature files.
3084
3085        By default, all features will be deleted.
3086        No essential information can get lost, storing them only saves time. Be careful with deleting datasets
3087        while training or inference is happening though.
3088
3089        Parameters
3090        ----------
3091        dataset_names : list, optional
3092            a list of dataset names to delete (by default all names are added)
3093        exceptions : list, optional
3094            a list of dataset names to not be deleted
3095        remove_active : bool, default False
3096            if `False`, datasets used by unfinished episodes will not be deleted
3097
3098        """
3099        print("Removing datasets...")
3100        if dataset_names is None:
3101            dataset_names = []
3102        if exceptions is None:
3103            exceptions = []
3104        if not remove_active:
3105            exceptions += self._episodes().get_active_datasets()
3106        dataset_path = os.path.join(self.project_path, "saved_datasets")
3107        if os.path.exists(dataset_path):
3108            if dataset_names == []:
3109                dataset_names = set([f.split(".")[0] for f in os.listdir(dataset_path)])
3110
3111            to_remove = [
3112                x
3113                for x in dataset_names
3114                if os.path.exists(os.path.join(dataset_path, x)) and x not in exceptions
3115            ]
3116            if len(to_remove) > 2:
3117                to_remove = tqdm(to_remove)
3118            for dataset in to_remove:
3119                shutil.rmtree(os.path.join(dataset_path, dataset))
3120            to_remove = [
3121                f"{x}.pickle"
3122                for x in dataset_names
3123                if os.path.exists(os.path.join(dataset_path, f"{x}.pickle"))
3124                and x not in exceptions
3125            ]
3126            for dataset in to_remove:
3127                os.remove(os.path.join(dataset_path, dataset))
3128            names = self._saved_datasets().dataset_names()
3129            self._saved_datasets().remove(names)
3130        print("\n")
3131
3132    def remove_extra_checkpoints(
3133        self, episode_names: List = None, exceptions: List = None
3134    ) -> None:
3135        """Remove intermediate model checkpoint files (only leave the files for the last epoch).
3136
3137        By default, all intermediate checkpoints will be deleted.
3138        Files in the model folder that are not associated with any record in the meta files are also deleted.
3139
3140        Parameters
3141        ----------
3142        episode_names : list, optional
3143            a list of episode names to clean (by default all names are added)
3144        exceptions : list, optional
3145            a list of episode names to not clean
3146
3147        """
3148        model_path = os.path.join(self.project_path, "results", "model")
3149        try:
3150            all_names = self._episodes().data.index
3151        except:
3152            all_names = os.listdir(model_path)
3153        if episode_names is None:
3154            episode_names = all_names
3155        if exceptions is None:
3156            exceptions = []
3157        to_remove = [x for x in episode_names if x not in exceptions]
3158        folders = os.listdir(model_path)
3159        for folder in folders:
3160            if folder not in all_names:
3161                shutil.rmtree(os.path.join(model_path, folder))
3162            elif folder in to_remove:
3163                files = os.listdir(os.path.join(model_path, folder))
3164                for file in sorted(files)[:-1]:
3165                    os.remove(os.path.join(model_path, folder, file))
3166
3167    def remove_search(self, search_name: str) -> None:
3168        """Remove a hyperparameter search record.
3169
3170        Parameters
3171        ----------
3172        search_name : str
3173            the name of the search to remove
3174
3175        """
3176        self._searches().remove_episode(search_name)
3177        graph_path = os.path.join(self.project_path, "results", "searches", search_name)
3178        if os.path.exists(graph_path):
3179            shutil.rmtree(graph_path)
3180
3181    def remove_suggestion(self, suggestion_name: str) -> None:
3182        """Remove a suggestion record.
3183
3184        Parameters
3185        ----------
3186        suggestion_name : str
3187            the name of the suggestion to remove
3188
3189        """
3190        self._suggestions().remove_episode(suggestion_name)
3191        suggestion_path = os.path.join(
3192            self.project_path, "results", "suggestions", suggestion_name
3193        )
3194        if os.path.exists(suggestion_path):
3195            shutil.rmtree(suggestion_path)
3196
3197    def remove_prediction(self, prediction_name: str) -> None:
3198        """Remove a prediction record.
3199
3200        Parameters
3201        ----------
3202        prediction_name : str
3203            the name of the prediction to remove
3204
3205        """
3206        self._predictions().remove_episode(prediction_name)
3207        prediction_path = self.prediction_path(prediction_name)
3208        if os.path.exists(prediction_path):
3209            shutil.rmtree(prediction_path)
3210
3211    def check_prediction_exists(self, prediction_name: str) -> str | None:
3212        """Check if a prediction exists.
3213
3214        Parameters
3215        ----------
3216        prediction_name : str
3217            the name of the prediction to check
3218
3219        Returns
3220        -------
3221        str | None
3222            the path to the prediction if it exists, `None` otherwise
3223
3224        """
3225        prediction_path = self.prediction_path(prediction_name)
3226        if os.path.exists(prediction_path):
3227            return prediction_path
3228        return None
3229
3230    def remove_episode(self, episode_name: str) -> None:
3231        """Remove all model, logs and metafile records related to an episode.
3232
3233        Parameters
3234        ----------
3235        episode_name : str
3236            the name of the episode to remove
3237
3238        """
3239        runs = self._episodes().get_runs(episode_name)
3240        runs.append(episode_name)
3241        for run in runs:
3242            self._episodes().remove_episode(run)
3243            model_path = os.path.join(self.project_path, "results", "model", run)
3244            if os.path.exists(model_path):
3245                shutil.rmtree(model_path)
3246            log_path = os.path.join(self.project_path, "results", "logs", f"{run}.txt")
3247            if os.path.exists(log_path):
3248                os.remove(log_path)
3249
3250    @abstractmethod
3251    def _reformat_results(res: dict, classes: dict, exclusive=False):
3252        """Add classes to micro metrics in results from evaluation"""
3253        results = deepcopy(res)
3254        for key in results.keys():
3255            if isinstance(results[key], list):
3256                if exclusive and len(classes) == len(results[key]) + 1:
3257                    other_ind = list(classes.keys())[
3258                        list(classes.values()).index("other")
3259                    ]
3260                    classes = {
3261                        (i if i < other_ind else i - 1): c
3262                        for i, c in classes.items()
3263                        if i != other_ind
3264                    }
3265                assert len(results[key]) == len(
3266                    classes
3267                ), f"Results for {key} have {len(results[key])} values, but {len(classes)} classes were provided!"
3268                results[key] = {
3269                    classes[i]: float(v) for i, v in enumerate(results[key])
3270                }
3271        return results
3272
3273    def prune_unfinished(self, exceptions: List = None) -> List:
3274        """Remove all interrupted episodes.
3275
3276        Remove all episodes that either don't have a log file or have less epochs in the log file than in
3277        the training parameters or have a model folder but not a record. Note that it can remove episodes that are
3278        currently running!
3279
3280        Parameters
3281        ----------
3282        exceptions : list
3283            the episodes to keep even if they are interrupted
3284
3285        Returns
3286        -------
3287        pruned : list
3288            a list of the episode names that were pruned
3289
3290        """
3291        if exceptions is None:
3292            exceptions = []
3293        unfinished = self._episodes().unfinished_episodes()
3294        unfinished = [x for x in unfinished if x not in exceptions]
3295        model_folders = os.listdir(os.path.join(self.project_path, "results", "model"))
3296        unfinished += [
3297            x for x in model_folders if x not in self._episodes().list_episodes().index
3298        ]
3299        print(f"PRUNING {unfinished}")
3300        for episode_name in unfinished:
3301            self.remove_episode(episode_name)
3302        print(f"\n")
3303        return unfinished
3304
3305    def prediction_path(self, prediction_name: str) -> str:
3306        """Get the path where prediction files are saved.
3307
3308        Parameters
3309        ----------
3310        prediction_name : str
3311            name of the prediction
3312
3313        Returns
3314        -------
3315        prediction_path : str
3316            the file path
3317
3318        """
3319        return os.path.join(
3320            self.project_path, "results", "predictions", f"{prediction_name}"
3321        )
3322
3323    def suggestion_path(self, suggestion_name: str) -> str:
3324        """Get the path where suggestion files are saved.
3325
3326        Parameters
3327        ----------
3328        suggestion_name : str
3329            name of the prediction
3330
3331        Returns
3332        -------
3333        suggestion_path : str
3334            the file path
3335
3336        """
3337        return os.path.join(
3338            self.project_path, "results", "suggestions", f"{suggestion_name}"
3339        )
3340
3341    @classmethod
3342    def print_data_types(cls):
3343        """Print available data types."""
3344        print("DATA TYPES:")
3345        for key, value in cls.data_types().items():
3346            print(f"{key}:")
3347            print(value.__doc__)
3348
3349    @classmethod
3350    def print_annotation_types(cls):
3351        """Print available annotation types."""
3352        print("ANNOTATION TYPES:")
3353        for key, value in cls.annotation_types():
3354            print(f"{key}:")
3355            print(value.__doc__)
3356
3357    @staticmethod
3358    def data_types() -> List:
3359        """Get available data types.
3360
3361        Returns
3362        -------
3363        data_types : list
3364            available data types
3365
3366        """
3367        return options.input_stores
3368
3369    @staticmethod
3370    def annotation_types() -> List:
3371        """Get available annotation types.
3372
3373        Returns
3374        -------
3375        list
3376            available annotation types
3377
3378        """
3379        return options.annotation_stores
3380
3381    def _save_mask(self, file: Dict, mask_name: str):
3382        """Save a mask file.
3383
3384        Parameters
3385        ----------
3386        file : dict
3387            the mask file data to save
3388        mask_name : str
3389            the name of the mask file
3390
3391        """
3392        if not os.path.exists(self._mask_path()):
3393            os.mkdir(self._mask_path())
3394        with open(os.path.join(self._mask_path(), mask_name + ".pickle"), "wb") as f:
3395            pickle.dump(file, f)
3396
3397    def _load_mask(self, mask_name: str) -> Dict:
3398        """Load a mask file.
3399
3400        Parameters
3401        ----------
3402        mask_name : str
3403            the name of the mask file to load
3404
3405        Returns
3406        -------
3407        mask : dict
3408            the loaded mask data
3409
3410        """
3411        with open(os.path.join(self._mask_path(), mask_name + ".pickle"), "rb") as f:
3412            data = pickle.load(f)
3413        return data
3414
3415    def _thresholds(self) -> DecisionThresholds:
3416        """Get the decision thresholds meta object.
3417
3418        Returns
3419        -------
3420        thresholds : DecisionThresholds
3421            the decision thresholds meta object
3422
3423        """
3424        return DecisionThresholds(self._thresholds_path())
3425
3426    def _episodes(self) -> SavedRuns:
3427        """Get the episodes meta object.
3428
3429        Returns
3430        -------
3431        episodes : SavedRuns
3432            the episodes meta object
3433
3434        """
3435        try:
3436            return SavedRuns(self._episodes_path(), self.project_path)
3437        except:
3438            self.load_metadata_backup()
3439            return SavedRuns(self._episodes_path(), self.project_path)
3440
3441    def _suggestions(self) -> Suggestions:
3442        """Get the suggestions meta object.
3443
3444        Returns
3445        -------
3446        suggestions : Suggestions
3447            the suggestions meta object
3448
3449        """
3450        try:
3451            return Suggestions(self._suggestions_path(), self.project_path)
3452        except:
3453            self.load_metadata_backup()
3454            return Suggestions(self._suggestions_path(), self.project_path)
3455
3456    def _predictions(self) -> SavedRuns:
3457        """Get the predictions meta object.
3458
3459        Returns
3460        -------
3461        predictions : SavedRuns
3462            the predictions meta object
3463
3464        """
3465        try:
3466            return SavedRuns(self._predictions_path(), self.project_path)
3467        except:
3468            self.load_metadata_backup()
3469            return SavedRuns(self._predictions_path(), self.project_path)
3470
3471    def _saved_datasets(self) -> SavedStores:
3472        """Get the datasets meta object.
3473
3474        Returns
3475        -------
3476        datasets : SavedStores
3477            the datasets meta object
3478
3479        """
3480        try:
3481            return SavedStores(self._saved_datasets_path())
3482        except:
3483            self.load_metadata_backup()
3484            return SavedStores(self._saved_datasets_path())
3485
3486    def _prediction(self, name: str) -> Run:
3487        """Get a prediction meta object.
3488
3489        Parameters
3490        ----------
3491        name : str
3492            episode name
3493
3494        Returns
3495        -------
3496        prediction : Run
3497            the prediction meta object
3498
3499        """
3500        try:
3501            return Run(name, self.project_path, meta_path=self._predictions_path())
3502        except:
3503            self.load_metadata_backup()
3504            return Run(name, self.project_path, meta_path=self._predictions_path())
3505
3506    def _episode(self, name: str) -> Run:
3507        """Get an episode meta object.
3508
3509        Parameters
3510        ----------
3511        name : str
3512            episode name
3513
3514        Returns
3515        -------
3516        episode : Run
3517            the episode meta object
3518
3519        """
3520        try:
3521            return Run(name, self.project_path, meta_path=self._episodes_path())
3522        except:
3523            self.load_metadata_backup()
3524            return Run(name, self.project_path, meta_path=self._episodes_path())
3525
3526    def _searches(self) -> Searches:
3527        """Get the hyperparameter search meta object.
3528
3529        Returns
3530        -------
3531        searches : Searches
3532            the searches meta object
3533
3534        """
3535        try:
3536            return Searches(self._searches_path(), self.project_path)
3537        except:
3538            self.load_metadata_backup()
3539            return Searches(self._searches_path(), self.project_path)
3540
3541    def _update_configs(self) -> None:
3542        """Update the project config files with newly added files and parameters.
3543
3544        This method updates the project configuration with the data path and copies
3545        any new configuration files from the original package to the project.
3546
3547        """
3548        self.update_parameters({"data": {"data_path": self.data_path}})
3549        folders = ["augmentations", "features", "model"]
3550        original_path = os.path.join(
3551            os.path.dirname(os.path.dirname(__file__)), "config"
3552        )
3553        project_path = os.path.join(self.project_path, "config")
3554        filenames = [x for x in os.listdir(original_path) if x.endswith("yaml")]
3555        for folder in folders:
3556            filenames += [
3557                os.path.join(folder, x)
3558                for x in os.listdir(os.path.join(original_path, folder))
3559            ]
3560        filenames.append(os.path.join("data", f"{self.data_type}.yaml"))
3561        if self.annotation_type != "none":
3562            filenames.append(os.path.join("annotation", f"{self.annotation_type}.yaml"))
3563        for file in filenames:
3564            filepath_original = os.path.join(original_path, file)
3565            if file.startswith("data") or file.startswith("annotation"):
3566                file = os.path.basename(file)
3567            filepath_project = os.path.join(project_path, file)
3568            if not os.path.exists(filepath_project):
3569                shutil.copy(filepath_original, filepath_project)
3570            else:
3571                original_pars = self._open_yaml(filepath_original)
3572                project_pars = self._open_yaml(filepath_project)
3573                to_remove = []
3574                for key, value in project_pars.items():
3575                    if key not in original_pars:
3576                        if key not in ["data_type", "annotation_type"]:
3577                            to_remove.append(key)
3578                for key in to_remove:
3579                    project_pars.pop(key)
3580                to_remove = []
3581                for key, value in original_pars.items():
3582                    if key in project_pars:
3583                        to_remove.append(key)
3584                for key in to_remove:
3585                    original_pars.pop(key)
3586                project_pars = self._update(project_pars, original_pars)
3587                with open(filepath_project, "w", encoding="utf-8") as f:
3588                    YAML().dump(project_pars, f)
3589
3590    def _update_project(self) -> None:
3591        """Update project files with the current version."""
3592        version_file = self._version_path()
3593        ok = True
3594        if not os.path.exists(version_file):
3595            ok = False
3596        else:
3597            with open(version_file, encoding="utf-8") as f:
3598                project_version = f.read()
3599            if project_version < __version__:
3600                ok = False
3601            elif project_version > __version__:
3602                warnings.warn(
3603                    f"The project expects a higher dlc2action version ({project_version}), please update!"
3604                )
3605        if not ok:
3606            project_config_path = os.path.join(self.project_path, "config")
3607            config_path = os.path.join(
3608                os.path.dirname(os.path.dirname(__path__)), "config"
3609            )
3610            episodes = self._episodes()
3611            folders = ["annotation", "augmentations", "data", "features", "model"]
3612
3613            project_annotation_configs = os.listdir(
3614                os.path.join(project_config_path, "annotation")
3615            )
3616            annotation_configs = os.listdir(os.path.join(config_path, "annotation"))
3617            for ann_config in annotation_configs:
3618                if ann_config not in project_annotation_configs:
3619                    shutil.copytree(
3620                        os.path.join(config_path, "annotation", ann_config),
3621                        os.path.join(project_config_path, "annotation", ann_config),
3622                        dirs_exist_ok=True,
3623                    )
3624                else:
3625                    project_pars = self._open_yaml(
3626                        os.path.join(project_config_path, "annotation", ann_config)
3627                    )
3628                    pars = self._open_yaml(
3629                        os.path.join(config_path, "annotation", ann_config)
3630                    )
3631                    new_keys = set(pars.keys()) - set(project_pars.keys())
3632                    for key in new_keys:
3633                        project_pars[key] = pars[key]
3634                        c = self._get_comment(pars.ca.items.get(key))
3635                        project_pars.yaml_add_eol_comment(c, key=key)
3636                        episodes.update(
3637                            condition=f"general/annotation_type::={ann_config}",
3638                            update={f"data/{key}": pars[key]},
3639                        )
3640
3641    def _initialize_project(
3642        self,
3643        data_type: str,
3644        annotation_type: str = None,
3645        data_path: str = None,
3646        annotation_path: str = None,
3647        copy: bool = True,
3648    ) -> None:
3649        """Initialize a new project."""
3650        if data_type not in self.data_types():
3651            raise ValueError(
3652                f"The {data_type} data type is not available. "
3653                f"Please choose from {self.data_types()}"
3654            )
3655        if annotation_type not in self.annotation_types():
3656            raise ValueError(
3657                f"The {annotation_type} annotation type is not available. "
3658                f"Please choose from {self.annotation_types()}"
3659            )
3660        os.mkdir(self.project_path)
3661        folders = ["results", "saved_datasets", "meta", "config"]
3662        for f in folders:
3663            os.mkdir(os.path.join(self.project_path, f))
3664        results_subfolders = [
3665            "model",
3666            "logs",
3667            "predictions",
3668            "splits",
3669            "searches",
3670            "suggestions",
3671        ]
3672        for sf in results_subfolders:
3673            os.mkdir(os.path.join(self.project_path, "results", sf))
3674        if data_path is not None:
3675            if copy:
3676                os.mkdir(os.path.join(self.project_path, "data"))
3677                shutil.copytree(
3678                    data_path,
3679                    os.path.join(self.project_path, "data"),
3680                    dirs_exist_ok=True,
3681                )
3682                data_path = os.path.join(self.project_path, "data")
3683        if annotation_path is not None:
3684            if copy:
3685                os.mkdir(os.path.join(self.project_path, "annotation"))
3686                shutil.copytree(
3687                    annotation_path,
3688                    os.path.join(self.project_path, "annotation"),
3689                    dirs_exist_ok=True,
3690                )
3691                annotation_path = os.path.join(self.project_path, "annotation")
3692        self._generate_config(
3693            data_type,
3694            annotation_type,
3695            data_path=data_path,
3696            annotation_path=annotation_path,
3697        )
3698        self._generate_meta()
3699
3700    def _read_types(self) -> Tuple[str, str]:
3701        """Get data type and annotation type from existing project files."""
3702        config_path = os.path.join(self.project_path, "config", "general.yaml")
3703        with open(config_path, encoding="utf-8") as f:
3704            pars = YAML().load(f)
3705        data_type = pars["data_type"]
3706        annotation_type = pars["annotation_type"]
3707        return annotation_type, data_type
3708
3709    def _read_paths(self) -> Tuple[str, str]:
3710        """Get data type and annotation type from existing project files."""
3711        config_path = os.path.join(self.project_path, "config", "data.yaml")
3712        with open(config_path, encoding="utf-8") as f:
3713            pars = YAML().load(f)
3714        data_path = pars["data_path"]
3715        annotation_path = pars["annotation_path"]
3716        return annotation_path, data_path
3717
3718    def _generate_config(
3719        self, data_type: str, annotation_type: str, data_path: str, annotation_path: str
3720    ) -> None:
3721        """Initialize the config files."""
3722        default_path = os.path.join(
3723            os.path.dirname(os.path.dirname(__file__)), "config"
3724        )
3725        config_path = os.path.join(self.project_path, "config")
3726        files = ["losses", "metrics", "ssl", "training"]
3727        for f in files:
3728            shutil.copy(os.path.join(default_path, f"{f}.yaml"), config_path)
3729        shutil.copytree(
3730            os.path.join(default_path, "model"), os.path.join(config_path, "model")
3731        )
3732        shutil.copytree(
3733            os.path.join(default_path, "features"),
3734            os.path.join(config_path, "features"),
3735        )
3736        shutil.copytree(
3737            os.path.join(default_path, "augmentations"),
3738            os.path.join(config_path, "augmentations"),
3739        )
3740        yaml = YAML()
3741        data_param_path = os.path.join(default_path, "data", f"{data_type}.yaml")
3742        data_params = None
3743        if os.path.exists(data_param_path):
3744            with open(data_param_path, encoding="utf-8") as f:
3745                data_params = yaml.load(f)
3746        if data_params is None:
3747            data_params = {}
3748        if annotation_type is None:
3749            ann_params = {}
3750        else:
3751            ann_param_path = os.path.join(
3752                default_path, "annotation", f"{annotation_type}.yaml"
3753            )
3754            if os.path.exists(ann_param_path):
3755                ann_params = self._open_yaml(ann_param_path)
3756            elif annotation_type == "none":
3757                ann_params = {}
3758            else:
3759                raise ValueError(
3760                    f"The {annotation_type} data type is not available. "
3761                    f"Please choose from {BehaviorDataset.annotation_types()}"
3762                )
3763        if ann_params is None:
3764            ann_params = {}
3765        data_params = self._update(data_params, ann_params)
3766        data_params["data_path"] = data_path
3767        data_params["annotation_path"] = annotation_path
3768        with open(os.path.join(config_path, "data.yaml"), "w", encoding="utf-8") as f:
3769            yaml.dump(data_params, f)
3770        with open(os.path.join(default_path, "general.yaml"), encoding="utf-8") as f:
3771            general_params = yaml.load(f)
3772        general_params["data_type"] = data_type
3773        general_params["annotation_type"] = annotation_type
3774        with open(
3775            os.path.join(config_path, "general.yaml"), "w", encoding="utf-8"
3776        ) as f:
3777            yaml.dump(general_params, f)
3778
3779    def _generate_meta(self) -> None:
3780        """Initialize the meta files."""
3781        config_file = os.path.join(self.project_path, "config")
3782        meta_fields = ["time"]
3783        columns = [("meta", field) for field in meta_fields]
3784        episodes = pd.DataFrame(columns=pd.MultiIndex.from_tuples(columns))
3785        episodes.to_pickle(self._episodes_path())
3786        meta_fields = ["time", "objective"]
3787        result_fields = ["best_params", "best_value"]
3788        columns = [("meta", field) for field in meta_fields] + [
3789            ("results", field) for field in result_fields
3790        ]
3791        searches = pd.DataFrame(columns=pd.MultiIndex.from_tuples(columns))
3792        searches.to_pickle(self._searches_path())
3793        meta_fields = ["time"]
3794        columns = [("meta", field) for field in meta_fields]
3795        predictions = pd.DataFrame(columns=pd.MultiIndex.from_tuples(columns))
3796        predictions.to_pickle(self._predictions_path())
3797        suggestions = pd.DataFrame(columns=pd.MultiIndex.from_tuples(columns))
3798        suggestions.to_pickle(self._suggestions_path())
3799        with open(os.path.join(config_file, "data.yaml"), encoding="utf-8") as f:
3800            data_keys = list(YAML().load(f).keys())
3801        saved_data = pd.DataFrame(columns=data_keys)
3802        saved_data.to_pickle(self._saved_datasets_path())
3803        pd.DataFrame().to_pickle(self._thresholds_path())
3804        # with open(self._version_path()) as f:
3805        #     f.write(__version__)
3806
3807    def _open_yaml(self, path: str) -> CommentedMap:
3808        """Load a parameter dictionary from a .yaml file."""
3809        with open(path, encoding="utf-8") as f:
3810            data = YAML().load(f)
3811        if data is None:
3812            data = {}
3813        return data
3814
3815    def _compare(self, d: Dict, u: Dict, allow_diff: float = 1e-7):
3816        """Compare nested dictionaries with 'almost equal' condition."""
3817        ok = True
3818        if u.keys() != d.keys():
3819            ok = False
3820        else:
3821            for k, v in u.items():
3822                if isinstance(v, Mapping):
3823                    ok = self._compare(d[k], v, allow_diff=allow_diff)
3824                else:
3825                    if isinstance(v, float) or isinstance(d[k], float):
3826                        if not isinstance(d[k], float) and not isinstance(d[k], int):
3827                            ok = False
3828                        elif not isinstance(v, float) and not isinstance(v, int):
3829                            ok = False
3830                        elif np.abs(v - d[k]) > allow_diff:
3831                            ok = False
3832                    elif v != d[k]:
3833                        ok = False
3834        return ok
3835
3836    def _check_comment(self, comment_sequence: List) -> bool:
3837        """Check if a comment already exists in a ruamel.yaml comment sequence."""
3838        if comment_sequence is None:
3839            return False
3840        c = self._get_comment(comment_sequence)
3841        if c != "":
3842            return True
3843        else:
3844            return False
3845
3846    def _get_comment(self, comment_sequence: List, strip=True) -> str:
3847        """Get the comment string from a ruamel.yaml comment sequence."""
3848        if comment_sequence is None:
3849            return ""
3850        c = ""
3851        for cm in comment_sequence:
3852            if cm is not None:
3853                if isinstance(cm, Iterable):
3854                    for c in cm:
3855                        if c is not None:
3856                            c = c.value
3857                            break
3858                    break
3859                else:
3860                    c = cm.value
3861                    break
3862        if strip:
3863            c = c.strip()
3864        return c
3865
3866    def _update(self, d: Union[CommentedMap, Dict], u: Union[CommentedMap, Dict]):
3867        """Update a nested dictionary."""
3868        if "general" in u and "model_name" in u["general"] and "model" in d:
3869            model_name = u["general"]["model_name"]
3870            if d["general"]["model_name"] != model_name:
3871                d["model"] = self._open_yaml(
3872                    os.path.join(
3873                        self.project_path, "config", "model", f"{model_name}.yaml"
3874                    )
3875                )
3876        d_copied = deepcopy(d)
3877        for k, v in u.items():
3878            if (
3879                k in d_copied
3880                and isinstance(d_copied[k], list)
3881                and isinstance(v, Mapping)
3882                and all([isinstance(x, int) for x in v.keys()])
3883            ):
3884                for kk, vv in v.items():
3885                    d_copied[k][kk] = vv
3886            elif (
3887                isinstance(v, Mapping)
3888                and k in d_copied
3889                and isinstance(d_copied[k], Mapping)
3890            ):
3891                if d_copied[k] is None:
3892                    d_k = CommentedMap()
3893                else:
3894                    d_k = d_copied[k]
3895                d_copied[k] = self._update(d_k, v)
3896            else:
3897                d_copied[k] = v
3898                if isinstance(u, CommentedMap) and u.ca.items.get(k) is not None:
3899                    c = self._get_comment(u.ca.items.get(k), strip=False)
3900                    if isinstance(d_copied, CommentedMap) and not self._check_comment(
3901                        d_copied.ca.items.get(k)
3902                    ):
3903                        d_copied.yaml_add_eol_comment(c, key=k)
3904        return d_copied
3905
3906    def _update_with_search(
3907        self,
3908        d: Dict,
3909        search_name: str,
3910        load_parameters: list = None,
3911        round_to_binary: list = None,
3912    ):
3913        """Update a dictionary with best parameters from a hyperparameter search."""
3914        u, _ = self._searches().get_best_params(
3915            search_name, load_parameters, round_to_binary
3916        )
3917        return self._update(d, u)
3918
3919    def _read_parameters(self, catch_blanks=True) -> Dict:
3920        """Compose a parameter dictionary to create a task from the config files."""
3921        config_path = os.path.join(self.project_path, "config")
3922        keys = [
3923            "data",
3924            "general",
3925            "losses",
3926            "metrics",
3927            "ssl",
3928            "training",
3929        ]
3930        parameters = {}
3931        for key in keys:
3932            parameters[key] = self._open_yaml(os.path.join(config_path, f"{key}.yaml"))
3933        features = parameters["general"]["feature_extraction"]
3934        parameters["features"] = self._open_yaml(
3935            os.path.join(config_path, "features", f"{features}.yaml")
3936        )
3937        transformer = options.extractor_to_transformer[features]
3938        parameters["augmentations"] = self._open_yaml(
3939            os.path.join(config_path, "augmentations", f"{transformer}.yaml")
3940        )
3941        model = parameters["general"]["model_name"]
3942        parameters["model"] = self._open_yaml(
3943            os.path.join(config_path, "model", f"{model}.yaml")
3944        )
3945        # input = parameters["general"]["input"]
3946        # parameters["model"] = self._open_yaml(
3947        #     os.path.join(config_path, "model", f"{model}.yaml")
3948        # )
3949        if catch_blanks:
3950            blanks = self._get_blanks()
3951            if len(blanks) > 0:
3952                self.list_blanks()
3953                raise ValueError(
3954                    f"Please fill in all the blanks before running experiments"
3955                )
3956        return parameters
3957
3958    def set_main_parameters(self, model_name: str = None, metric_names: List = None):
3959        """Select the model and the metrics.
3960
3961        Parameters
3962        ----------
3963        model_name : str, optional
3964            model name; run `project.help("model") to find out more
3965        metric_names : list, optional
3966            a list of metric function names; run `project.help("metrics") to find out more
3967
3968        """
3969        pars = {"general": {}}
3970        if model_name is not None:
3971            assert model_name in options.models
3972            pars["general"]["model_name"] = model_name
3973        if metric_names is not None:
3974            for metric in metric_names:
3975                assert metric in options.metrics
3976            pars["general"]["metric_functions"] = metric_names
3977        self.update_parameters(pars)
3978
3979    def help(self, keyword: str = None):
3980        """Get information on available options.
3981
3982        Parameters
3983        ----------
3984        keyword : str, optional
3985            the keyword for options (run without arguments to see which keywords are available)
3986
3987        """
3988        if keyword is None:
3989            print("AVAILABLE HELP FUNCTIONS:")
3990            print("- Try running `project.help(keyword)` with the following keywords:")
3991            print("    - model: to get more information on available models,")
3992            print(
3993                "    - features: to get more information on available feature extraction modes,"
3994            )
3995            print(
3996                "    - partition_method: to get more information on available train/test/val partitioning methods,"
3997            )
3998            print("    - metrics: to see a list of available metric functions.")
3999            print("    - data: to see help for expected data structure")
4000            print(
4001                "- To start working with this project, first run `project.list_blanks()` to check which parameters need to be filled in."
4002            )
4003            print(
4004                "- After a model and metrics are set, run `project.list_basic_parameters()` to see a list of the most important parameters that you might want to modify"
4005            )
4006            print(
4007                f"- If you want to dig deeper, get the full dictionary with project._read_parameters() (it is a `ruamel.yaml.comments.CommentedMap` instance)."
4008            )
4009        elif keyword == "model":
4010            print("MODELS:")
4011            for key, model in options.models.items():
4012                print(f"{key}:")
4013                print(model.__doc__)
4014        elif keyword == "features":
4015            print("FEATURE EXTRACTORS:")
4016            for key, extractor in options.feature_extractors.items():
4017                print(f"{key}:")
4018                print(extractor.__doc__)
4019        elif keyword == "partition_method":
4020            print("PARTITION METHODS:")
4021            print(
4022                BehaviorDataset.partition_train_test_val.__doc__.split(
4023                    "The partitioning method:"
4024                )[1].split("val_frac :")[0]
4025            )
4026        elif keyword == "metrics":
4027            print("METRICS:")
4028            for key, metric in options.metrics.items():
4029                print(f"{key}:")
4030                print(metric.__doc__)
4031        elif keyword == "data":
4032            print("DATA:")
4033            print(f"Video data: {self.data_type}")
4034            print(options.input_stores[self.data_type].__doc__)
4035            print(f"Annotation data: {self.annotation_type}")
4036            print(options.annotation_stores[self.annotation_type].__doc__)
4037            print(
4038                "Annotation path and data path don't have to be separate, you can keep everything in one folder."
4039            )
4040        else:
4041            raise ValueError(f"The {keyword} keyword is not recognized")
4042        print("\n")
4043
4044    def _process_value(self, value):
4045        """Process a configuration value for display.
4046
4047        Parameters
4048        ----------
4049        value : any
4050            the value to process
4051
4052        Returns
4053        -------
4054        processed_value : any
4055            the processed value
4056
4057        """
4058        if isinstance(value, str):
4059            value = f'"{value}"'
4060        elif isinstance(value, CommentedSet):
4061            value = {x for x in value}
4062        return value
4063
4064    def _get_blanks(self):
4065        """Get a list of blank (unset) parameters in the configuration.
4066
4067        Returns
4068        -------
4069        caught : list
4070            a list of parameter keys that have blank values
4071
4072        """
4073        caught = []
4074        parameters = self._read_parameters(catch_blanks=False)
4075        for big_key, big_value in parameters.items():
4076            for key, value in big_value.items():
4077                if value == "???":
4078                    caught.append(
4079                        (big_key, key, self._get_comment(big_value.ca.items.get(key)))
4080                    )
4081        return caught
4082
4083    def list_blanks(self, blanks=None):
4084        """List parameters that need to be filled in.
4085
4086        Parameters
4087        ----------
4088        blanks : list, optional
4089            a list of the parameters to list, if already known
4090
4091        """
4092        if blanks is None:
4093            blanks = self._get_blanks()
4094        if len(blanks) > 0:
4095            to_update = defaultdict(lambda: [])
4096            for b, k, c in blanks:
4097                to_update[b].append((k, c))
4098            print("Before running experiments, please update all the blanks.")
4099            print("To do that, you can run this.")
4100            print("--------------------------------------------------------")
4101            print(f"project.update_parameters(")
4102            print(f"    {{")
4103            for big_key, keys in to_update.items():
4104                print(f'        "{big_key}": {{')
4105                for key, comment in keys:
4106                    print(f'            "{key}": ..., {comment}')
4107                print(f"        }}")
4108            print(f"    }}")
4109            print(")")
4110            print("--------------------------------------------------------")
4111            print("Replace ... with relevant values.")
4112        else:
4113            print("There is no blanks left!")
4114
4115    def list_basic_parameters(
4116        self,
4117    ):
4118        """Get a list of most relevant parameters and code to modify them."""
4119        parameters = self._read_parameters()
4120        print("BASIC PARAMETERS:")
4121        model_name = parameters["general"]["model_name"]
4122        metric_names = parameters["general"]["metric_functions"]
4123        loss_name = parameters["general"]["loss_function"]
4124        feature_extraction = parameters["general"]["feature_extraction"]
4125        print("Here is a list of current parameters.")
4126        print(
4127            "You can copy this code, change the parameters you want to set and run it to update the project config."
4128        )
4129        print("--------------------------------------------------------")
4130        print("project.update_parameters(")
4131        print("    {")
4132        for group in ["general", "data", "training"]:
4133            print(f'        "{group}": {{')
4134            for key in options.basic_parameters[group]:
4135                if key in parameters[group]:
4136                    print(
4137                        f'            "{key}": {self._process_value(parameters[group][key])}, {self._get_comment(parameters[group].ca.items.get(key))}'
4138                    )
4139            print("        },")
4140        print('        "losses": {')
4141        print(f'            "{loss_name}": {{')
4142        for key in options.basic_parameters["losses"][loss_name]:
4143            if key in parameters["losses"][loss_name]:
4144                print(
4145                    f'                "{key}": {self._process_value(parameters["losses"][loss_name][key])}, {self._get_comment(parameters["losses"][loss_name].ca.items.get(key))}'
4146                )
4147        print("            },")
4148        print("        },")
4149        print('        "metrics": {')
4150        for metric in metric_names:
4151            print(f'            "{metric}": {{')
4152            for key in parameters["metrics"][metric]:
4153                print(
4154                    f'                "{key}": {self._process_value(parameters["metrics"][metric][key])}, {self._get_comment(parameters["metrics"][metric].ca.items.get(key))}'
4155                )
4156            print("            },")
4157        print("        },")
4158        print('        "model": {')
4159        for key in options.basic_parameters["model"][model_name]:
4160            if key in parameters["model"]:
4161                print(
4162                    f'            "{key}": {self._process_value(parameters["model"][key])}, {self._get_comment(parameters["model"].ca.items.get(key))}'
4163                )
4164
4165        print("        },")
4166        print('        "features": {')
4167        for key in options.basic_parameters["features"][feature_extraction]:
4168            if key in parameters["features"]:
4169                print(
4170                    f'            "{key}": {self._process_value(parameters["features"][key])}, {self._get_comment(parameters["features"].ca.items.get(key))}'
4171                )
4172
4173        print("        },")
4174        print('        "augmentations": {')
4175        for key in options.basic_parameters["augmentations"][feature_extraction]:
4176            if key in parameters["augmentations"]:
4177                print(
4178                    f'            "{key}": {self._process_value(parameters["augmentations"][key])}, {self._get_comment(parameters["augmentations"].ca.items.get(key))}'
4179                )
4180        print("        },")
4181        print("    },")
4182        print(")")
4183        print("--------------------------------------------------------")
4184        print("\n")
4185
4186    def _create_record(
4187        self,
4188        episode_name: str,
4189        behaviors_dict: Dict,
4190        load_episode: str = None,
4191        parameters_update: Dict = None,
4192        task: TaskDispatcher = None,
4193        load_epoch: int = None,
4194        load_search: str = None,
4195        load_parameters: list = None,
4196        round_to_binary: list = None,
4197        load_strict: bool = True,
4198        n_seeds: int = 1,
4199    ) -> TaskDispatcher:
4200        """Create a meta data episode record."""
4201        if episode_name in self._episodes().data.index:
4202            return
4203        if type(n_seeds) is not int or n_seeds < 1:
4204            raise ValueError(
4205                f"The n_seeds parameter has to be an integer larger than 0; got {n_seeds}"
4206            )
4207        if parameters_update is None:
4208            parameters_update = {}
4209        parameters = self._read_parameters()
4210        parameters = self._update(parameters, parameters_update)
4211        if load_search is not None:
4212            parameters = self._update_with_search(
4213                parameters, load_search, load_parameters, round_to_binary
4214            )
4215        parameters = self._fill(
4216            parameters,
4217            episode_name,
4218            load_episode,
4219            load_epoch=load_epoch,
4220            only_load_model=True,
4221            load_strict=load_strict,
4222            continuing=True,
4223        )
4224        self._save_episode(episode_name, parameters, behaviors_dict)
4225        return task
4226
4227    def _save_thresholds(
4228        self,
4229        episode_names: List,
4230        metric_name: str,
4231        parameters: Dict,
4232        thresholds: List,
4233        load_epochs: List,
4234    ):
4235        """Save optimal decision thresholds in the meta records."""
4236        metric_parameters = parameters["metrics"][metric_name]
4237        self._thresholds().save_thresholds(
4238            episode_names, load_epochs, metric_name, metric_parameters, thresholds
4239        )
4240
4241    def _save_episode(
4242        self,
4243        episode_name: str,
4244        parameters: Dict,
4245        behaviors_dict: Dict,
4246        suppress_validation: bool = False,
4247        training_time: str = None,
4248        norm_stats: Dict = None,
4249    ) -> None:
4250        """Save an episode in the meta files."""
4251        try:
4252            split_info = self._split_info_from_filename(
4253                parameters["training"]["split_path"]
4254            )
4255            parameters["training"]["partition_method"] = split_info["partition_method"]
4256        except:
4257            pass
4258        if norm_stats is not None:
4259            norm_stats = dict(norm_stats)
4260        parameters["training"]["stats"] = norm_stats
4261        self._episodes().save_episode(
4262            episode_name,
4263            parameters,
4264            behaviors_dict,
4265            suppress_validation=suppress_validation,
4266            training_time=training_time,
4267        )
4268
4269    def _save_suggestions(
4270        self, suggestions_name: str, parameters: Dict, meta_parameters: Dict
4271    ) -> None:
4272        """Save a suggestion in the meta files."""
4273        self._suggestions().save_suggestion(
4274            suggestions_name, parameters, meta_parameters
4275        )
4276
4277    def _update_episode_results(
4278        self,
4279        episode_name: str,
4280        logs: Tuple,
4281        training_time: str = None,
4282    ) -> None:
4283        """Save the results of a run in the meta files."""
4284        self._episodes().update_episode_results(episode_name, logs, training_time)
4285
4286    def _save_prediction(
4287        self,
4288        prediction_name: str,
4289        predicted: Dict[str, Dict],
4290        parameters: Dict,
4291        task: TaskDispatcher,
4292        mode: str = "test",
4293        embedding: bool = False,
4294        inference_time: str = None,
4295        behavior_dict: List[Dict[str, Any]] = None,
4296    ) -> None:
4297        """Save a prediction in the meta files."""
4298
4299        folder = self.prediction_path(prediction_name)
4300        os.mkdir(folder)
4301        for video_id, prediction in predicted.items():
4302            with open(
4303                os.path.join(
4304                    folder, video_id + f"_{prediction_name}_prediction.pickle"
4305                ),
4306                "wb",
4307            ) as f:
4308                prediction["min_frames"], prediction["max_frames"] = task.dataset(
4309                    mode
4310                ).get_min_max_frames(video_id)
4311                prediction["classes"] = behavior_dict
4312                pickle.dump(prediction, f)
4313
4314        parameters = self._update(
4315            parameters,
4316            {"meta": {"embedding": embedding, "inference_time": inference_time}},
4317        )
4318        self._predictions().save_episode(
4319            prediction_name, parameters, task.behaviors_dict()
4320        )
4321
4322    def _save_search(
4323        self,
4324        search_name: str,
4325        parameters: Dict,
4326        n_trials: int,
4327        best_params: Dict,
4328        best_value: float,
4329        metric: str,
4330        search_space: Dict,
4331    ) -> None:
4332        """Save a hyperparameter search in the meta files."""
4333        self._searches().save_search(
4334            search_name,
4335            parameters,
4336            n_trials,
4337            best_params,
4338            best_value,
4339            metric,
4340            search_space,
4341        )
4342
4343    def _save_stores(self, parameters: Dict) -> None:
4344        """Save a pickled dataset in the meta files."""
4345        name = os.path.basename(parameters["data"]["feature_save_path"])
4346        self._saved_datasets().save_store(name, self._get_data_pars(parameters))
4347        self.create_metadata_backup()
4348
4349    def _remove_stores(self, parameters: Dict, remove_active: bool = False) -> None:
4350        """Remove the pre-computed features folder."""
4351        name = os.path.basename(parameters["data"]["feature_save_path"])
4352        if remove_active or name not in self._episodes().get_active_datasets():
4353            self.remove_saved_features([name])
4354
4355    def _check_episode_validity(
4356        self, episode_name: str, allow_doublecolon: bool = False, force: bool = False
4357    ) -> None:
4358        """Check whether the episode name is valid."""
4359        if episode_name.startswith("_"):
4360            raise ValueError(
4361                "Names starting with an underscore are reserved by dlc2action and cannot be used!"
4362            )
4363        elif "." in episode_name:
4364            raise ValueError("Names containing '.' cannot be used!")
4365        if not allow_doublecolon and "#" in episode_name:
4366            raise ValueError(
4367                "Names containing '#' are reserved by dlc2action and cannot be used!"
4368            )
4369        if "::" in episode_name:
4370            raise ValueError(
4371                "Names containing '::' are reserved by dlc2action and cannot be used!"
4372            )
4373        if force:
4374            self.remove_episode(episode_name)
4375        elif not self._episodes().check_name_validity(episode_name):
4376            raise ValueError(
4377                f"The {episode_name} name is already taken! Set force=True to overwrite."
4378            )
4379
4380    def _check_search_validity(self, search_name: str, force: bool = False) -> None:
4381        """Check whether the search name is valid."""
4382        if search_name.startswith("_"):
4383            raise ValueError(
4384                "Names starting with an underscore are reserved by dlc2action and cannot be used!"
4385            )
4386        elif "." in search_name:
4387            raise ValueError("Names containing '.' cannot be used!")
4388        if force:
4389            self.remove_search(search_name)
4390        elif not self._searches().check_name_validity(search_name):
4391            raise ValueError(f"The {search_name} name is already taken!")
4392
4393    def _check_prediction_validity(
4394        self, prediction_name: str, force: bool = False
4395    ) -> None:
4396        """Check whether the prediction name is valid."""
4397        if prediction_name.startswith("_"):
4398            raise ValueError(
4399                "Names starting with an underscore are reserved by dlc2action and cannot be used!"
4400            )
4401        elif "." in prediction_name:
4402            raise ValueError("Names containing '.' cannot be used!")
4403        if force:
4404            self.remove_prediction(prediction_name)
4405        elif not self._predictions().check_name_validity(prediction_name):
4406            raise ValueError(f"The {prediction_name} name is already taken!")
4407
4408    def _check_suggestions_validity(
4409        self, suggestions_name: str, force: bool = False
4410    ) -> None:
4411        """Check whether the suggestions name is valid."""
4412        if suggestions_name.startswith("_"):
4413            raise ValueError(
4414                "Names starting with an underscore are reserved by dlc2action and cannot be used!"
4415            )
4416        elif "." in suggestions_name:
4417            raise ValueError("Names containing '.' cannot be used!")
4418        if force:
4419            self.remove_suggestion(suggestions_name)
4420        elif not self._suggestions().check_name_validity(suggestions_name):
4421            raise ValueError(f"The {suggestions_name} name is already taken!")
4422
4423    def _training_time(self, episode_name: str) -> int:
4424        """Get the training time of an episode in seconds."""
4425        return self._episode(episode_name).training_time()
4426
4427    def _mask_path(self) -> str:
4428        """Get the path to the masks folder.
4429
4430        Returns
4431        -------
4432        path : str
4433            the path to the masks folder
4434
4435        """
4436        return os.path.join(self.project_path, "results", "masks")
4437
4438    def _thresholds_path(self) -> str:
4439        """Get the path to the thresholds meta file.
4440
4441        Returns
4442        -------
4443        path : str
4444            the path to the thresholds meta file
4445
4446        """
4447        return os.path.join(self.project_path, "meta", "thresholds.pickle")
4448
4449    def _episodes_path(self) -> str:
4450        """Get the path to the episodes meta file.
4451
4452        Returns
4453        -------
4454        path : str
4455            the path to the episodes meta file
4456
4457        """
4458        return os.path.join(self.project_path, "meta", "episodes.pickle")
4459
4460    def _suggestions_path(self) -> str:
4461        """Get the path to the suggestions meta file.
4462
4463        Returns
4464        -------
4465        path : str
4466            the path to the suggestions meta file
4467
4468        """
4469        return os.path.join(self.project_path, "meta", "suggestions.pickle")
4470
4471    def _saved_datasets_path(self) -> str:
4472        """Get the path to the datasets meta file.
4473
4474        Returns
4475        -------
4476        path : str
4477            the path to the datasets meta file
4478
4479        """
4480        return os.path.join(self.project_path, "meta", "saved_datasets.pickle")
4481
4482    def _predictions_path(self) -> str:
4483        """Get the path to the predictions meta file.
4484
4485        Returns
4486        -------
4487        path : str
4488            the path to the predictions meta file
4489
4490        """
4491        return os.path.join(self.project_path, "meta", "predictions.pickle")
4492
4493    def _dataset_store_path(self, name: str) -> str:
4494        """Get the path to a specific pickled dataset.
4495
4496        Parameters
4497        ----------
4498        name : str
4499            the name of the dataset
4500
4501        Returns
4502        -------
4503        path : str
4504            the path to the dataset file
4505
4506        """
4507        return os.path.join(self.project_path, "saved_datasets", f"{name}.pickle")
4508
4509    def _al_points_path(self, suggestions_name: str) -> str:
4510        """Get the path to an active learning intervals file.
4511
4512        Parameters
4513        ----------
4514        suggestions_name : str
4515            the name of the suggestions
4516
4517        Returns
4518        -------
4519        path : str
4520            the path to the active learning points file
4521
4522        """
4523        path = os.path.join(
4524            self.project_path,
4525            "results",
4526            "suggestions",
4527            suggestions_name,
4528            f"{suggestions_name}_al_points.pickle",
4529        )
4530        return path
4531
4532    def _suggestion_path(self, v_id: str, suggestions_name: str) -> str:
4533        """Get the path to a suggestion file.
4534
4535        Parameters
4536        ----------
4537        v_id : str
4538            the video ID
4539        suggestions_name : str
4540            the name of the suggestions
4541
4542        Returns
4543        -------
4544        path : str
4545            the path to the suggestion file
4546
4547        """
4548        path = os.path.join(
4549            self.project_path,
4550            "results",
4551            "suggestions",
4552            suggestions_name,
4553            f"{v_id}_suggestion.pickle",
4554        )
4555        return path
4556
4557    def _searches_path(self) -> str:
4558        """Get the path to the hyperparameter search meta file.
4559
4560        Returns
4561        -------
4562        path : str
4563            the path to the searches meta file
4564
4565        """
4566        return os.path.join(self.project_path, "meta", "searches.pickle")
4567
4568    def _search_path(self, name: str) -> str:
4569        """Get the default path to the graph folder for a specific hyperparameter search.
4570
4571        Parameters
4572        ----------
4573        name : str
4574            the name of the search
4575
4576        Returns
4577        -------
4578        path : str
4579            the path to the search folder
4580
4581        """
4582        return os.path.join(self.project_path, "results", "searches", name)
4583
4584    def _version_path(self) -> str:
4585        """Get the path to the version file.
4586
4587        Returns
4588        -------
4589        path : str
4590            the path to the version file
4591
4592        """
4593        return os.path.join(self.project_path, "meta", "version.txt")
4594
4595    def _default_split_file(self, split_info: Dict) -> Optional[str]:
4596        """Generate a path to a split file from split parameters.
4597
4598        Parameters
4599        ----------
4600        split_info : dict
4601            the split parameters dictionary
4602
4603        Returns
4604        -------
4605        split_file_path : str or None
4606            the path to the split file, or None if not applicable
4607
4608        """
4609        if split_info["partition_method"].startswith("time"):
4610            return None
4611        val_frac = split_info["val_frac"]
4612        test_frac = split_info["test_frac"]
4613        split_name = f'{split_info["partition_method"]}_val{val_frac * 100}%_test{test_frac * 100}%_len{split_info["len_segment"]}_overlap{split_info["overlap"]}'
4614        if not split_info["only_load_annotated"]:
4615            split_name += "_all"
4616        split_name += ".txt"
4617        return os.path.join(self.project_path, "results", "splits", split_name)
4618
4619    def _split_info_from_filename(self, split_name: str) -> Dict:
4620        """Get split parameters from default path to a split file.
4621
4622        Parameters
4623        ----------
4624        split_name : str
4625            the name/path of the split file
4626
4627        Returns
4628        -------
4629        split_info : dict
4630            the split parameters dictionary
4631
4632        """
4633        if split_name is None:
4634            return {}
4635        try:
4636            name = os.path.basename(split_name)[:-4]
4637            split = name.split("_")
4638            if len(split) == 6:
4639                only_load_annotated = False
4640            else:
4641                only_load_annotated = True
4642            len_segment = int(split[3][3:])
4643            overlap = float(split[4][7:])
4644            if overlap > 1:
4645                overlap = int(overlap)
4646            method, val, test = split[:3]
4647            val = float(val[3:-1]) / 100
4648            test = float(test[4:-1]) / 100
4649            return {
4650                "partition_method": method,
4651                "val_frac": val,
4652                "test_frac": test,
4653                "only_load_annotated": only_load_annotated,
4654                "len_segment": len_segment,
4655                "overlap": overlap,
4656            }
4657        except:
4658            return {"partition_method": "file"}
4659
4660    def _fill(
4661        self,
4662        parameters: Dict,
4663        episode_name: str,
4664        load_experiment: str = None,
4665        load_epoch: int = None,
4666        load_strict: bool = True,
4667        only_load_model: bool = False,
4668        continuing: bool = False,
4669        enforce_split_parameters: bool = False,
4670    ) -> Dict:
4671        """Update the parameters from the config files with project specific information.
4672
4673        Fill in the constant file path parameters and generate a unique log file and a model folder.
4674        Fill in the split file if the same split has been run before in the project and change partition method to
4675        from_file.
4676        Fill in saved data path if a dataset with the same data parameters already exists in the project.
4677        If load_experiment is not None, fill in the checkpoint path as well.
4678        The only_load_model training parameter is defined by the corresponding argument.
4679        If continuing is True, new files are not created and all information is loaded from load_experiment.
4680        If prediction is True, log and model files are not created.
4681        The enforce_split_parameters parameter is used to resolve conflicts
4682        between split file path and split parameters when they arise.
4683
4684        Parameters
4685        ----------
4686        parameters : dict
4687            the parameters dictionary to update
4688        episode_name : str
4689            the name of the episode
4690        load_experiment : str, optional
4691            the name of the experiment to load from
4692        load_epoch : int, optional
4693            the epoch to load (by default the last one)
4694        load_strict : bool, default True
4695            if `True`, strict loading is enforced
4696        only_load_model : bool, default False
4697            if `True`, only the model is loaded
4698        continuing : bool, default False
4699            if `True`, continues from existing files
4700        enforce_split_parameters : bool, default False
4701            if `True`, split parameters are enforced
4702
4703        Returns
4704        -------
4705        parameters : dict
4706            the updated parameters dictionary
4707
4708        """
4709        pars = deepcopy(parameters)
4710        if episode_name == "_":
4711            self.remove_episode("_")
4712        log = os.path.join(self.project_path, "results", "logs", f"{episode_name}.txt")
4713        model_save_path = os.path.join(
4714            self.project_path, "results", "model", episode_name
4715        )
4716        if not continuing and (os.path.exists(log) or os.path.exists(model_save_path)):
4717            raise ValueError(
4718                f"The {episode_name} episode name is already in use! Set force=True to overwrite."
4719            )
4720        keys = ["val_frac", "test_frac", "partition_method"]
4721        if "len_segment" not in pars["general"] and "len_segment" in pars["data"]:
4722            pars["general"]["len_segment"] = pars["data"]["len_segment"]
4723        if "overlap" not in pars["general"] and "overlap" in pars["data"]:
4724            pars["general"]["overlap"] = pars["data"]["overlap"]
4725        if "len_segment" in pars["data"]:
4726            pars["data"].pop("len_segment")
4727        if "overlap" in pars["data"]:
4728            pars["data"].pop("overlap")
4729        split_info = {k: pars["training"][k] for k in keys}
4730        split_info["only_load_annotated"] = pars["general"]["only_load_annotated"]
4731        split_info["len_segment"] = pars["general"]["len_segment"]
4732        split_info["overlap"] = pars["general"]["overlap"]
4733        pars["training"]["log_file"] = log
4734        if not os.path.exists(model_save_path):
4735            os.mkdir(model_save_path)
4736        pars["training"]["model_save_path"] = model_save_path
4737        if load_experiment is not None:
4738            if load_experiment not in self._episodes().data.index:
4739                raise ValueError(f"The {load_experiment} episode does not exist!")
4740            old_episode = self._episode(load_experiment)
4741            old_file = old_episode.split_file()
4742            old_info = self._split_info_from_filename(old_file)
4743            if len(old_info) == 0:
4744                old_info = old_episode.split_info()
4745            if enforce_split_parameters:
4746                if split_info["partition_method"] != "file":
4747                    pars["training"]["split_path"] = self._default_split_file(
4748                        split_info
4749                    )
4750            else:
4751                equal = True
4752                if old_info["partition_method"] != split_info["partition_method"]:
4753                    equal = False
4754                if old_info["partition_method"] != "file":
4755                    if (
4756                        old_info["val_frac"] != split_info["val_frac"]
4757                        or old_info["test_frac"] != split_info["test_frac"]
4758                    ):
4759                        equal = False
4760                if not continuing and not equal:
4761                    warnings.warn(
4762                        f"The partitioning parameters in the loaded experiment ({old_info}) "
4763                        f"are not equal to the current partitioning parameters ({split_info}). "
4764                        f"The current parameters are replaced."
4765                    )
4766                pars["training"]["split_path"] = old_file
4767                for k, v in old_info.items():
4768                    pars["training"][k] = v
4769            pars["training"]["checkpoint_path"] = old_episode.model_file(load_epoch)
4770            pars["training"]["load_strict"] = load_strict
4771        else:
4772            pars["training"]["checkpoint_path"] = None
4773            if pars["training"]["partition_method"] == "file":
4774                if (
4775                    "split_path" not in pars["training"]
4776                    or pars["training"]["split_path"] is None
4777                ):
4778                    raise ValueError(
4779                        "The partition_method parameter is set to file but the "
4780                        "split_path parameter is not set!"
4781                    )
4782                elif not os.path.exists(pars["training"]["split_path"]):
4783                    raise ValueError(
4784                        f'The {pars["training"]["split_path"]} split file does not exist'
4785                    )
4786            else:
4787                pars["training"]["split_path"] = self._default_split_file(split_info)
4788        pars["training"]["only_load_model"] = only_load_model
4789        pars["data"]["saved_data_path"] = None
4790        pars["data"]["feature_save_path"] = None
4791        pars_data_copy = self._get_data_pars(pars)
4792        saved_data_name = self._saved_datasets().find_name(pars_data_copy)
4793        if saved_data_name is not None:
4794            pars["data"]["saved_data_path"] = self._dataset_store_path(saved_data_name)
4795            pars["data"]["feature_save_path"] = self._dataset_store_path(
4796                saved_data_name
4797            ).split(".")[0]
4798        else:
4799            dataset_path = self._dataset_store_path(episode_name)
4800            if os.path.exists(dataset_path):
4801                name, ext = dataset_path.split(".")
4802                i = 0
4803                while os.path.exists(f"{name}_{i}.{ext}"):
4804                    i += 1
4805                dataset_path = f"{name}_{i}.{ext}"
4806            pars["data"]["saved_data_path"] = dataset_path
4807            pars["data"]["feature_save_path"] = dataset_path.split(".")[0]
4808        split_split = pars["training"]["partition_method"].split(":")
4809        random = True
4810        for partition_method in options.partition_methods["fixed"]:
4811            method_split = partition_method.split(":")
4812            if len(split_split) != len(method_split):
4813                continue
4814            equal = True
4815            for x, y in zip(split_split, method_split):
4816                if y.startswith("{"):
4817                    continue
4818                if x != y:
4819                    equal = False
4820                    break
4821            if equal:
4822                random = False
4823                break
4824        if random and os.path.exists(pars["training"]["split_path"]):
4825            pars["training"]["partition_method"] = "file"
4826        pars["general"]["save_dataset"] = True
4827        # Check len_segment for c2f models
4828        if pars["general"]["model_name"].startswith("c2f"):
4829            if int(pars["general"]["len_segment"]) < 512:
4830                raise ValueError(
4831                    "The segment length should be higher than 512 when using one of the C2F models"
4832                )
4833        return pars
4834
4835    def _get_data_pars(self, pars: Dict) -> Dict:
4836        """Get a complete description of the data from a general parameters dictionary.
4837
4838        Parameters
4839        ----------
4840        pars : dict
4841            the general parameters dictionary
4842
4843        Returns
4844        -------
4845        pars_data : dict
4846            the complete data parameters dictionary
4847
4848        """
4849        pars_data_copy = deepcopy(pars["data"])
4850        for par in [
4851            "only_load_annotated",
4852            "exclusive",
4853            "feature_extraction",
4854            "ignored_clips",
4855            "len_segment",
4856            "overlap",
4857        ]:
4858            pars_data_copy[par] = pars["general"].get(par, None)
4859        pars_data_copy.update(pars["features"])
4860        return pars_data_copy
4861
4862    def _make_al_points_from_suggestions(
4863        self,
4864        suggestions_name: str,
4865        task: TaskDispatcher,
4866        predicted_classes: Dict,
4867        background_threshold: Optional[float],
4868        visibility_min_score: float,
4869        visibility_min_frac: float,
4870        num_behaviors: int,
4871    ):
4872        valleys = []
4873        if background_threshold is not None:
4874            for i in range(num_behaviors):
4875                print(f"generating background for behavior {i}...")
4876                valleys.append(
4877                    task.dataset("train").find_valleys(
4878                        predicted_classes,
4879                        threshold=background_threshold,
4880                        visibility_min_score=visibility_min_score,
4881                        visibility_min_frac=visibility_min_frac,
4882                        main_class=i,
4883                        low=True,
4884                        cut_annotated=True,
4885                        min_frames=1,
4886                    )
4887                )
4888        valleys = task.dataset("train").valleys_intersection(valleys)
4889        folder = os.path.join(
4890            self.project_path, "results", "suggestions", suggestions_name
4891        )
4892        os.makedirs(os.path.dirname(folder), exist_ok=True)
4893        res = {}
4894        for file in os.listdir(folder):
4895            video_id = file.split("_suggestion.p")[0]
4896            res[video_id] = []
4897            with open(os.path.join(folder, file), "rb") as f:
4898                data = pickle.load(f)
4899            for clip_id, ind_list in zip(data[2], data[3]):
4900                max_len = max(
4901                    [
4902                        max([x[1] for x in cat_list]) if len(cat_list) > 0 else 0
4903                        for cat_list in ind_list
4904                    ]
4905                )
4906                if max_len == 0:
4907                    continue
4908                arr = torch.zeros(max_len)
4909                for cat_list in ind_list:
4910                    for start, end, amb in cat_list:
4911                        arr[start:end] = 1
4912                if video_id in valleys:
4913                    for start, end, clip in valleys[video_id]:
4914                        if clip == clip_id:
4915                            arr[start:end] = 1
4916                output, indices, counts = torch.unique_consecutive(
4917                    arr > 0, return_inverse=True, return_counts=True
4918                )
4919                long_indices = torch.where(output)[0]
4920                res[video_id] += [
4921                    (
4922                        (indices == i).nonzero(as_tuple=True)[0][0].item(),
4923                        (indices == i).nonzero(as_tuple=True)[0][-1].item(),
4924                        clip_id,
4925                    )
4926                    for i in long_indices
4927                ]
4928        return res
4929
4930    def _make_al_points(
4931        self,
4932        task: TaskDispatcher,
4933        predicted_error: torch.Tensor,
4934        predicted_classes: torch.Tensor,
4935        exclude_classes: List,
4936        exclude_threshold: List,
4937        exclude_threshold_diff: List,
4938        exclude_hysteresis: List,
4939        include_classes: List,
4940        include_threshold: List,
4941        include_threshold_diff: List,
4942        include_hysteresis: List,
4943        error_episode: str = None,
4944        error_class: str = None,
4945        suggestion_episodes: List = None,
4946        error_threshold: float = 0.5,
4947        error_threshold_diff: float = 0.1,
4948        error_hysteresis: bool = False,
4949        min_frames_al: int = 30,
4950        visibility_min_score: float = 5,
4951        visibility_min_frac: float = 0.7,
4952    ) -> Dict:
4953        """Generate an active learning file."""
4954        if len(exclude_classes) > 0 or len(include_classes) > 0:
4955            valleys = []
4956            included = None
4957            excluded = None
4958            for class_name, thr, thr_diff, hysteresis in zip(
4959                exclude_classes,
4960                exclude_threshold,
4961                exclude_threshold_diff,
4962                exclude_hysteresis,
4963            ):
4964                episode = self._episodes().get_runs(suggestion_episodes[0])[0]
4965                class_index = self._episode(episode).get_class_ind(class_name)
4966                valleys.append(
4967                    task.dataset("train").find_valleys(
4968                        predicted_classes,
4969                        predicted_error=predicted_error,
4970                        min_frames=min_frames_al,
4971                        threshold=thr,
4972                        visibility_min_score=visibility_min_score,
4973                        visibility_min_frac=visibility_min_frac,
4974                        error_threshold=error_threshold,
4975                        main_class=class_index,
4976                        low=True,
4977                        threshold_diff=thr_diff,
4978                        min_frames_error=min_frames_al,
4979                        hysteresis=hysteresis,
4980                    )
4981                )
4982            if len(valleys) > 0:
4983                included = task.dataset("train").valleys_union(valleys)
4984            valleys = []
4985            for class_name, thr, thr_diff, hysteresis in zip(
4986                include_classes,
4987                include_threshold,
4988                include_threshold_diff,
4989                include_hysteresis,
4990            ):
4991                episode = self._episodes().get_runs(suggestion_episodes[0])[0]
4992                class_index = self._episode(episode).get_class_ind(class_name)
4993                valleys.append(
4994                    task.dataset("train").find_valleys(
4995                        predicted_classes,
4996                        predicted_error=predicted_error,
4997                        min_frames=min_frames_al,
4998                        threshold=thr,
4999                        visibility_min_score=visibility_min_score,
5000                        visibility_min_frac=visibility_min_frac,
5001                        error_threshold=error_threshold,
5002                        main_class=class_index,
5003                        low=False,
5004                        threshold_diff=thr_diff,
5005                        min_frames_error=min_frames_al,
5006                        hysteresis=hysteresis,
5007                    )
5008                )
5009            if len(valleys) > 0:
5010                excluded = task.dataset("train").valleys_union(valleys)
5011            al_points = task.dataset("train").valleys_intersection([included, excluded])
5012        else:
5013            class_index = self._episode(error_episode).get_class_ind(error_class)
5014            print("generating active learning intervals...")
5015            al_points = task.dataset("train").find_valleys(
5016                predicted_error,
5017                min_frames=min_frames_al,
5018                threshold=error_threshold,
5019                visibility_min_score=visibility_min_score,
5020                visibility_min_frac=visibility_min_frac,
5021                main_class=class_index,
5022                low=True,
5023                threshold_diff=error_threshold_diff,
5024                min_frames_error=min_frames_al,
5025                hysteresis=error_hysteresis,
5026            )
5027        for v_id in al_points:
5028            clip_dict = defaultdict(lambda: [])
5029            res = []
5030            for x in al_points[v_id]:
5031                clip_dict[x[-1]].append(x)
5032            for clip_id in clip_dict:
5033                clip_dict[clip_id] = sorted(clip_dict[clip_id])
5034                i = 0
5035                j = 1
5036                while j < len(clip_dict[clip_id]):
5037                    end = clip_dict[clip_id][i][1]
5038                    start = clip_dict[clip_id][j][0]
5039                    if start - end < 30:
5040                        clip_dict[clip_id][i][1] = clip_dict[clip_id][j][1]
5041                    else:
5042                        res.append(clip_dict[clip_id][i])
5043                        i = j
5044                    j += 1
5045                res.append(clip_dict[clip_id][i])
5046            al_points[v_id] = sorted(res)
5047        return al_points
5048
5049    def _make_suggestions(
5050        self,
5051        task: TaskDispatcher,
5052        predicted_error: torch.Tensor,
5053        predicted_classes: torch.Tensor,
5054        suggestion_threshold: List,
5055        suggestion_threshold_diff: List,
5056        suggestion_hysteresis: List,
5057        suggestion_episodes: List = None,
5058        suggestion_classes: List = None,
5059        error_threshold: float = 0.5,
5060        min_frames_suggestion: int = 3,
5061        min_frames_al: int = 30,
5062        visibility_min_score: float = 0,
5063        visibility_min_frac: float = 0.7,
5064        cut_annotated: bool = False,
5065    ) -> Dict:
5066        """Make a suggestions dictionary."""
5067        suggestions = defaultdict(lambda: {})
5068        for class_name, thr, thr_diff, hysteresis in zip(
5069            suggestion_classes,
5070            suggestion_threshold,
5071            suggestion_threshold_diff,
5072            suggestion_hysteresis,
5073        ):
5074            episode = self._episodes().get_runs(suggestion_episodes[0])[0]
5075            class_index = self._episode(episode).get_class_ind(class_name)
5076            print(f"generating suggestions for {class_name}...")
5077            found = task.dataset("train").find_valleys(
5078                predicted_classes,
5079                smooth_interval=2,
5080                predicted_error=predicted_error,
5081                min_frames=min_frames_suggestion,
5082                threshold=thr,
5083                visibility_min_score=visibility_min_score,
5084                visibility_min_frac=visibility_min_frac,
5085                error_threshold=error_threshold,
5086                main_class=class_index,
5087                low=False,
5088                threshold_diff=thr_diff,
5089                min_frames_error=min_frames_al,
5090                hysteresis=hysteresis,
5091                cut_annotated=cut_annotated,
5092            )
5093            for v_id in found:
5094                suggestions[v_id][class_name] = found[v_id]
5095        suggestions = dict(suggestions)
5096        return suggestions
5097
5098    def count_classes(
5099        self,
5100        load_episode: str = None,
5101        parameters_update: Dict = None,
5102        remove_saved_features: bool = False,
5103        bouts: bool = True,
5104    ) -> Dict:
5105        """Get a dictionary of class counts in different modes.
5106
5107        Parameters
5108        ----------
5109        load_episode : str, optional
5110            the episode settings to load
5111        parameters_update : dict, optional
5112            a dictionary of parameter updates (only for "data" and "general" categories)
5113        remove_saved_features : bool, default False
5114            if `True`, the dataset that is used for computation is then deleted
5115        bouts : bool, default False
5116            if `True`, instead of frame counts segment counts are returned
5117
5118        Returns
5119        -------
5120        class_counts : dict
5121            a dictionary where first-level keys are "train", "val" and "test", second-level keys are
5122            class names and values are class counts (in frames)
5123
5124        """
5125        if load_episode is None:
5126            task, parameters = self._make_task_training(
5127                episode_name="_", parameters_update=parameters_update, throwaway=True
5128            )
5129        else:
5130            task, parameters, _ = self._make_task_prediction(
5131                "_",
5132                load_episode=load_episode,
5133                parameters_update=parameters_update,
5134            )
5135        class_counts = task.count_classes(bouts=bouts)
5136        behaviors = task.behaviors_dict()
5137        class_counts = {
5138            kk: {behaviors.get(k, "unknown"): v for k, v in vv.items()}
5139            for kk, vv in class_counts.items()
5140        }
5141        if remove_saved_features:
5142            self._remove_stores(parameters)
5143        return class_counts
5144
5145    def plot_class_distribution(
5146        self,
5147        parameters_update: Dict = None,
5148        frame_cutoff: int = 1,
5149        bout_cutoff: int = 1,
5150        print_full: bool = False,
5151        remove_saved_features: bool = False,
5152        save: str = None,
5153    ) -> None:
5154        """Make a class distribution plot.
5155
5156        You can either specify the parameters, choose an existing dataset or do neither (in that case a dataset
5157        is created or loaded for the computation with the default parameters).
5158
5159        Parameters
5160        ----------
5161        parameters_update : dict, optional
5162            a dictionary of parameter updates (only for "data" and "general" categories)
5163        frame_cutoff : int, default 1
5164            the minimum number of frames for a segment to be considered
5165        bout_cutoff : int, default 1
5166            the minimum number of bouts for a class to be considered
5167        print_full : bool, default False
5168            if `True`, the full class distribution is printed
5169        remove_saved_features : bool, default False
5170            if `True`, the dataset that is used for computation is then deleted
5171
5172        """
5173        task, parameters = self._make_task_training(
5174            episode_name="_", parameters_update=parameters_update, throwaway=True
5175        )
5176        cutoff = {True: bout_cutoff, False: frame_cutoff}
5177        for bouts in [True, False]:
5178            class_counts = task.count_classes(bouts=bouts)
5179            if print_full:
5180                print("Bouts:" if bouts else "Frames:")
5181                for k, v in class_counts.items():
5182                    if sum(v.values()) != 0:
5183                        print(f"  {k}:")
5184                        values, keys = zip(
5185                            *[
5186                                x
5187                                for x in sorted(zip(v.values(), v.keys()), reverse=True)
5188                                if x[-1] != -100
5189                            ]
5190                        )
5191                        for kk, vv in zip(keys, values):
5192                            print(f"    {task.behaviors_dict()[kk]}: {vv}")
5193            class_counts = {
5194                kk: {k: v for k, v in vv.items() if v >= cutoff[bouts]}
5195                for kk, vv in class_counts.items()
5196            }
5197            for key, d in class_counts.items():
5198                if sum(d.values()) != 0:
5199                    values, keys = zip(
5200                        *[x for x in sorted(zip(d.values(), d.keys())) if x[-1] != -100]
5201                    )
5202                    keys = [task.behaviors_dict()[x] for x in keys]
5203                    plt.bar(keys, values)
5204                    plt.title(key)
5205                    plt.xticks(rotation=45, ha="right")
5206                    if bouts:
5207                        plt.ylabel("bouts")
5208                    else:
5209                        plt.ylabel("frames")
5210                    plt.tight_layout()
5211
5212                    if save is None:
5213                        plt.savefig(save)
5214                        plt.close()
5215                    else:
5216                        plt.show()
5217        if remove_saved_features:
5218            self._remove_stores(parameters)
5219
5220    def _generate_mask(
5221        self,
5222        mask_name: str,
5223        perc_annotated: float = 0.1,
5224        parameters_update: Dict = None,
5225        remove_saved_features: bool = False,
5226    ) -> None:
5227        """Generate a real_lens for active learning simulation.
5228
5229        Parameters
5230        ----------
5231        mask_name : str
5232            the name of the real_lens
5233        perc_annotated : float, default 0.1
5234            a
5235
5236        """
5237        print(f"GENERATING {mask_name}")
5238        task, parameters = self._make_task_training(
5239            f"_{mask_name}", parameters_update=parameters_update, throwaway=True
5240        )
5241        val_intervals, val_ids = task.dataset("val").get_intervals()  # 1
5242        unannotated_intervals = task.dataset("train").get_unannotated_intervals()  # 2
5243        unannotated_intervals = task.dataset("val").get_unannotated_intervals(
5244            first_intervals=unannotated_intervals
5245        )
5246        ids = task.dataset("train").get_ids()
5247        mask = {video_id: {} for video_id in ids}
5248        total_all = 0
5249        total_masked = 0
5250        for video_id, clip_ids in ids.items():
5251            for clip_id in clip_ids:
5252                frames = np.ones(task.dataset("train").get_len(video_id, clip_id))
5253                if clip_id in val_intervals[video_id]:
5254                    for start, end in val_intervals[video_id][clip_id]:
5255                        frames[start:end] = 0
5256                if clip_id in unannotated_intervals[video_id]:
5257                    for start, end in unannotated_intervals[video_id][clip_id]:
5258                        frames[start:end] = 0
5259                annotated = np.where(frames)[0]
5260                total_all += len(annotated)
5261                masked = annotated[-int(len(annotated) * (1 - perc_annotated)) :]
5262                total_masked += len(masked)
5263                mask[video_id][clip_id] = self._get_intervals(masked)
5264        file = {
5265            "masked": mask,
5266            "val_intervals": val_intervals,
5267            "val_ids": val_ids,
5268            "unannotated": unannotated_intervals,
5269        }
5270        self._save_mask(file, mask_name)
5271        if remove_saved_features:
5272            self._remove_stores(parameters)
5273        print("\n")
5274        # print(f'Unmasked: {sum([(vv == 0).sum() for v in real_lens.values() for vv in v.values()])} frames')
5275
5276    def _get_intervals(self, frame_indices: np.ndarray):
5277        """Get a list of intervals from a list of frame indices.
5278
5279        Example: `[0, 1, 2, 5, 6, 8] -> [[0, 3], [5, 7], [8, 9]]`.
5280
5281        Parameters
5282        ----------
5283        frame_indices : np.ndarray
5284            a list of frame indices
5285
5286        Returns
5287        -------
5288        intervals : list
5289            a list of interval boundaries
5290
5291        """
5292        masked_intervals = []
5293        if len(frame_indices) > 0:
5294            breaks = np.where(np.diff(frame_indices) != 1)[0]
5295            start = frame_indices[0]
5296            for k in breaks:
5297                masked_intervals.append([start, frame_indices[k] + 1])
5298                start = frame_indices[k + 1]
5299            masked_intervals.append([start, frame_indices[-1] + 1])
5300        return masked_intervals
5301
5302    def _update_mask_with_uncertainty(
5303        self,
5304        mask_name: str,
5305        episode_name: Union[str, None],
5306        classes: List,
5307        load_epoch: int = None,
5308        n_frames: int = 10000,
5309        method: str = "least_confidence",
5310        min_length: int = 30,
5311        augment_n: int = 0,
5312        parameters_update: Dict = None,
5313    ):
5314        """Update real_lens with frame-wise uncertainty scores for active learning.
5315
5316        Parameters
5317        ----------
5318        mask_name : str
5319            the name of the real_lens
5320        episode_name : str
5321            the name of the episode to load
5322        classes : list
5323            a list of class names or indices; their uncertainty scores will be computed separately and stacked
5324        load_epoch : int, optional
5325            the epoch to load (by default last; if this epoch is not saved the closest checkpoint is chosen)
5326        n_frames : int, default 10000
5327            the number of frames to "annotate"
5328        method : {"least_confidence", "entropy"}
5329            the method used to calculate the scores from the probability predictions (`"least_confidence"`: `1 - p_i` if
5330            `p_i > 0.5` or `p_i` if `p_i < 0.5`; `"entropy"`: `- p_i * log(p_i) - (1 - p_i) * log(1 - p_i)`)
5331        min_length : int
5332            the minimum length (in frames) of the annotated intervals
5333        augment_n : int, default 0
5334            the number of augmentations to average over
5335        parameters_update : dict, optional
5336            the dictionary used to update the parameters from the config
5337
5338        Returns
5339        -------
5340        score_dicts : dict
5341            a nested dictionary where first level keys are video ids, second level keys are clip ids and values
5342            are score tensors
5343
5344        """
5345        print(f"UPDATING {mask_name}")
5346        task, parameters, _ = self._make_task_prediction(
5347            prediction_name=mask_name,
5348            load_episode=episode_name,
5349            parameters_update=parameters_update,
5350            load_epoch=load_epoch,
5351            mode="train",
5352        )
5353        score_tensors = task.generate_uncertainty_score(classes, augment_n, method)
5354        self._update_mask(task, mask_name, score_tensors, n_frames, min_length)
5355        print("\n")
5356
5357    def _update_mask_with_BALD(
5358        self,
5359        mask_name: str,
5360        episode_name: str,
5361        classes: List,
5362        load_epoch: int = None,
5363        augment_n: int = 0,
5364        n_frames: int = 10000,
5365        num_models: int = 10,
5366        kernel_size: int = 11,
5367        min_length: int = 30,
5368        parameters_update: Dict = None,
5369    ):
5370        """Update real_lens with frame-wise Bayesian Active Learning by Disagreement scores for active learning.
5371
5372        Parameters
5373        ----------
5374        mask_name : str
5375            the name of the real_lens
5376        episode_name : str
5377            the name of the episode to load
5378        classes : list
5379            a list of class names or indices; their uncertainty scores will be computed separately and stacked
5380        load_epoch : int, optional
5381            the epoch to load (by default last)
5382        augment_n : int, default 0
5383            the number of augmentations to average over
5384        n_frames : int, default 10000
5385            the number of frames to "annotate"
5386        num_models : int, default 10
5387            the number of dropout masks to apply
5388        kernel_size : int, default 11
5389            the size of the smoothing gaussian kernel
5390        min_length : int
5391            the minimum length (in frames) of the annotated intervals
5392        parameters_update : dict, optional
5393            the dictionary used to update the parameters from the config
5394
5395        Returns
5396        -------
5397        score_dicts : dict
5398            a nested dictionary where first level keys are video ids, second level keys are clip ids and values
5399            are score tensors
5400
5401        """
5402        print(f"UPDATING {mask_name}")
5403        task, parameters, mode = self._make_task_prediction(
5404            mask_name,
5405            load_episode=episode_name,
5406            parameters_update=parameters_update,
5407            load_epoch=load_epoch,
5408        )
5409        score_tensors = task.generate_bald_score(
5410            classes, augment_n, num_models, kernel_size
5411        )
5412        self._update_mask(task, mask_name, score_tensors, n_frames, min_length)
5413        print("\n")
5414
5415    def _suggest_intervals(
5416        self,
5417        dataset: BehaviorDataset,
5418        score_tensors: Dict,
5419        n_frames: int,
5420        min_length: int,
5421    ) -> Dict:
5422        """Suggest intervals with highest score of total length `n_frames`.
5423
5424        Parameters
5425        ----------
5426        dataset : BehaviorDataset
5427            the dataset
5428        score_tensors : dict
5429            a dictionary where keys are clip ids and values are framewise score tensors
5430        n_frames : int
5431            the number of frames to "annotate"
5432        min_length : int
5433            minimum length of suggested intervals
5434
5435        Returns
5436        -------
5437        active_learning_intervals : Dict
5438            active learning dictionary with suggested intervals
5439
5440        """
5441        video_intervals, _ = dataset.get_intervals()
5442        taken = {
5443            video_id: defaultdict(lambda: {}) for video_id in video_intervals.keys()
5444        }
5445        annotated = dataset.get_annotated_intervals()
5446        for video_id in video_intervals:
5447            for clip_id in video_intervals[video_id]:
5448                taken[video_id][clip_id] = torch.zeros(
5449                    dataset.get_len(video_id, clip_id)
5450                )
5451                if video_id in annotated and clip_id in annotated[video_id]:
5452                    for start, end in annotated[video_id][clip_id]:
5453                        score_tensors[video_id][clip_id][:, start:end] = -10
5454                        taken[video_id][clip_id][int(start) : int(end)] = 1
5455        n_frames = (
5456            sum([(vv == 1).sum() for v in taken.values() for vv in v.values()])
5457            + n_frames
5458        )
5459        factor = 1
5460        threshold_start = float(
5461            torch.mean(
5462                torch.tensor(
5463                    [
5464                        torch.mean(
5465                            torch.tensor([torch.mean(y[y > 0]) for y in x.values()])
5466                        )
5467                        for x in score_tensors.values()
5468                    ]
5469                )
5470            )
5471        )
5472        while (
5473            sum([(vv == 1).sum() for v in taken.values() for vv in v.values()])
5474            < n_frames
5475        ):
5476            threshold = threshold_start * factor
5477            intervals = []
5478            interval_scores = []
5479            key1 = list(score_tensors.keys())[0]
5480            key2 = list(score_tensors[key1].keys())[0]
5481            num_scores = score_tensors[key1][key2].shape[0]
5482            for i in range(num_scores):
5483                v_dict = dataset.find_valleys(
5484                    predicted=score_tensors,
5485                    threshold=threshold,
5486                    min_frames=min_length,
5487                    main_class=i,
5488                    low=False,
5489                )
5490                for v_id, interval_list in v_dict.items():
5491                    intervals += [x + [v_id] for x in interval_list]
5492                    interval_scores += [
5493                        float(torch.mean(score_tensors[v_id][clip_id][i, start:end]))
5494                        for start, end, clip_id in interval_list
5495                    ]
5496            intervals = np.array(intervals)[np.argsort(interval_scores)[::-1]]
5497            i = 0
5498            while sum(
5499                [(vv == 1).sum() for v in taken.values() for vv in v.values()]
5500            ) < n_frames and i < len(intervals):
5501                start, end, clip_id, video_id = intervals[i]
5502                i += 1
5503                taken[video_id][clip_id][int(start) : int(end)] = 1
5504            factor *= 0.9
5505            if factor < 0.05:
5506                warnings.warn(f"Could not find enough frames!")
5507                break
5508        active_learning_intervals = {video_id: [] for video_id in video_intervals}
5509        for video_id in taken:
5510            for clip_id in taken[video_id]:
5511                if video_id in annotated and clip_id in annotated[video_id]:
5512                    for start, end in annotated[video_id][clip_id]:
5513                        taken[video_id][clip_id][int(start) : int(end)] = 0
5514                if (taken[video_id][clip_id] == 1).sum() == 0:
5515                    continue
5516                indices = np.where(taken[video_id][clip_id].numpy())[0]
5517                boundaries = self._get_intervals(indices)
5518                active_learning_intervals[video_id] += [
5519                    [start, end, clip_id] for start, end in boundaries
5520                ]
5521        return active_learning_intervals
5522
5523    def _update_mask(
5524        self,
5525        task: TaskDispatcher,
5526        mask_name: str,
5527        score_tensors: Dict,
5528        n_frames: int,
5529        min_length: int,
5530    ) -> None:
5531        """Update the real_lens with intervals with the highest score of total length `n_frames`.
5532
5533        Parameters
5534        ----------
5535        task : TaskDispatcher
5536            the task dispatcher object
5537        mask_name : str
5538            the name of the real_lens
5539        score_tensors : dict
5540            a dictionary where keys are clip ids and values are framewise score tensors
5541        n_frames : int
5542            the number of frames to "annotate"
5543        min_length : int
5544            the minimum length of the annotated intervals
5545
5546        """
5547        mask = self._load_mask(mask_name)
5548        video_intervals, _ = task.dataset("train").get_intervals()
5549        masked = {
5550            video_id: defaultdict(lambda: {}) for video_id in video_intervals.keys()
5551        }
5552        total_masked = 0
5553        total_all = 0
5554        for video_id in video_intervals:
5555            for clip_id in video_intervals[video_id]:
5556                masked[video_id][clip_id] = torch.zeros(
5557                    task.dataset("train").get_len(video_id, clip_id)
5558                )
5559                if (
5560                    video_id in mask["unannotated"]
5561                    and clip_id in mask["unannotated"][video_id]
5562                ):
5563                    for start, end in mask["unannotated"][video_id][clip_id]:
5564                        score_tensors[video_id][clip_id][:, start:end] = -10
5565                        masked[video_id][clip_id][int(start) : int(end)] = 1
5566                if (
5567                    video_id in mask["val_intervals"]
5568                    and clip_id in mask["val_intervals"][video_id]
5569                ):
5570                    for start, end in mask["val_intervals"][video_id][clip_id]:
5571                        score_tensors[video_id][clip_id][:, start:end] = -10
5572                        masked[video_id][clip_id][int(start) : int(end)] = 1
5573                total_all += torch.sum(masked[video_id][clip_id] == 0)
5574                if video_id in mask["masked"] and clip_id in mask["masked"][video_id]:
5575                    # print(f'{real_lens["masked"][video_id][clip_id]=}')
5576                    for start, end in mask["masked"][video_id][clip_id]:
5577                        masked[video_id][clip_id][int(start) : int(end)] = 1
5578                        total_masked += end - start
5579        old_n_frames = sum(
5580            [(vv == 0).sum() for v in masked.values() for vv in v.values()]
5581        )
5582        n_frames = old_n_frames + n_frames
5583        factor = 1
5584        while (
5585            sum([(vv == 0).sum() for v in masked.values() for vv in v.values()])
5586            < n_frames
5587        ):
5588            threshold = float(
5589                torch.mean(
5590                    torch.tensor(
5591                        [
5592                            torch.mean(
5593                                torch.tensor([torch.mean(y[y > 0]) for y in x.values()])
5594                            )
5595                            for x in score_tensors.values()
5596                        ]
5597                    )
5598                )
5599            )
5600            threshold = threshold * factor
5601            intervals = []
5602            interval_scores = []
5603            key1 = list(score_tensors.keys())[0]
5604            key2 = list(score_tensors[key1].keys())[0]
5605            num_scores = score_tensors[key1][key2].shape[0]
5606            for i in range(num_scores):
5607                v_dict = task.dataset("train").find_valleys(
5608                    predicted=score_tensors,
5609                    threshold=threshold,
5610                    min_frames=min_length,
5611                    main_class=i,
5612                    low=False,
5613                )
5614                for v_id, interval_list in v_dict.items():
5615                    intervals += [x + [v_id] for x in interval_list]
5616                    interval_scores += [
5617                        float(torch.mean(score_tensors[v_id][clip_id][i, start:end]))
5618                        for start, end, clip_id in interval_list
5619                    ]
5620            intervals = np.array(intervals)[np.argsort(interval_scores)[::-1]]
5621            i = 0
5622            while sum(
5623                [(vv == 0).sum() for v in masked.values() for vv in v.values()]
5624            ) < n_frames and i < len(intervals):
5625                start, end, clip_id, video_id = intervals[i]
5626                i += 1
5627                masked[video_id][clip_id][int(start) : int(end)] = 0
5628            factor *= 0.9
5629            if factor < 0.05:
5630                warnings.warn(f"Could not find enough frames!")
5631                break
5632        mask["masked"] = {video_id: {} for video_id in video_intervals}
5633        total_masked_new = 0
5634        for video_id in masked:
5635            for clip_id in masked[video_id]:
5636                if (
5637                    video_id in mask["unannotated"]
5638                    and clip_id in mask["unannotated"][video_id]
5639                ):
5640                    for start, end in mask["unannotated"][video_id][clip_id]:
5641                        masked[video_id][clip_id][int(start) : int(end)] = 0
5642                if (
5643                    video_id in mask["val_intervals"]
5644                    and clip_id in mask["val_intervals"][video_id]
5645                ):
5646                    for start, end in mask["val_intervals"][video_id][clip_id]:
5647                        masked[video_id][clip_id][int(start) : int(end)] = 0
5648                indices = np.where(masked[video_id][clip_id].numpy())[0]
5649                mask["masked"][video_id][clip_id] = self._get_intervals(indices)
5650        for video_id in mask["masked"]:
5651            for clip_id in mask["masked"][video_id]:
5652                for start, end in mask["masked"][video_id][clip_id]:
5653                    total_masked_new += end - start
5654        self._save_mask(mask, mask_name)
5655        with open(
5656            os.path.join(
5657                self.project_path, "results", f"{mask_name}.txt", encoding="utf-8"
5658            ),
5659            "a",
5660        ) as f:
5661            f.write(f"from {total_masked} to {total_masked_new} / {total_all}" + "\n")
5662        print(f"Unmasked from {total_masked} to {total_masked_new} / {total_all}")
5663
5664    def _visualize_results_label(
5665        self,
5666        episode_name: str,
5667        label: str,
5668        load_epoch: int = None,
5669        parameters_update: Dict = None,
5670        add_legend: bool = True,
5671        ground_truth: bool = True,
5672        hide_axes: bool = False,
5673        width: float = 10,
5674        whole_video: bool = False,
5675        transparent: bool = False,
5676        num_plots: int = 1,
5677        smooth_interval: int = 0,
5678    ):
5679        other_path = os.path.join(self.project_path, "results", "other")
5680        if not os.path.exists(other_path):
5681            os.mkdir(other_path)
5682        if parameters_update is None:
5683            parameters_update = {}
5684        if "model" in parameters_update.keys():
5685            raise ValueError("Cannot change model parameters after training!")
5686        task, parameters, _ = self._make_task_prediction(
5687            "_",
5688            load_episode=episode_name,
5689            parameters_update=parameters_update,
5690            load_epoch=load_epoch,
5691            mode="val",
5692        )
5693        for i in range(num_plots):
5694            print(i)
5695            task._visualize_results_label(
5696                smooth_interval=smooth_interval,
5697                label=label,
5698                save_path=os.path.join(
5699                    other_path, f"{episode_name}_prediction_{i}.jpg"
5700                ),
5701                add_legend=add_legend,
5702                ground_truth=ground_truth,
5703                hide_axes=hide_axes,
5704                whole_video=whole_video,
5705                transparent=transparent,
5706                dataset="val",
5707                width=width,
5708                title=str(i),
5709            )
5710
5711    def plot_confusion_matrix(
5712        self,
5713        episode_name: str,
5714        load_epoch: int = None,
5715        parameters_update: Dict = None,
5716        metric: str = "recall",
5717        mode: str = "val",
5718        remove_saved_features: bool = False,
5719        save_path: str = None,
5720        cmap: str = "viridis",
5721    ) -> Tuple[ndarray, Iterable]:
5722        """Make a confusion matrix plot and return the data.
5723
5724        If the annotation is non-exclusive, only false positive labels are considered.
5725
5726        Parameters
5727        ----------
5728        episode_name : str
5729            the name of the episode to load
5730        load_epoch : int, optional
5731            the index of the epoch to load (by default the last one is loaded)
5732        parameters_update : dict, optional
5733            a dictionary of parameter updates (only for "data" and "general" categories)
5734        metric : {"recall", "precision"}
5735            for datasets with non-exclusive annotation, if `type` is `"recall"`, only false positives are taken
5736            into account, and if `type` is `"precision"`, only false negatives
5737        mode : {'val', 'all', 'test', 'train'}
5738            the subset of the data to make the prediction for (forced to 'all' if data_path is not None)
5739        remove_saved_features : bool, default False
5740            if `True`, the dataset that is used for computation is then deleted
5741
5742        Returns
5743        -------
5744        confusion_matrix : np.ndarray
5745            a confusion matrix of shape `(#classes, #classes)` where `A[i, j] = F_ij/N_i`, `F_ij` is the number of
5746            frames that have the i-th label in the ground truth and a false positive j-th label in the prediction,
5747            `N_i` is the number of frames that have the i-th label in the ground truth
5748        classes : list
5749            a list of labels
5750
5751        """
5752        task, parameters, mode = self._make_task_prediction(
5753            "_",
5754            load_episode=episode_name,
5755            load_epoch=load_epoch,
5756            parameters_update=parameters_update,
5757            mode=mode,
5758        )
5759        dataset = task.dataset(mode)
5760        prediction = task.predict(dataset, raw_output=True)
5761        confusion_matrix, classes, type = dataset.get_confusion_matrix(prediction, type)
5762        if remove_saved_features:
5763            self._remove_stores(parameters)
5764        fig, ax = plt.subplots(figsize=(len(classes), len(classes)))
5765        ax.imshow(confusion_matrix, cmap=cmap)
5766        # Show all ticks and label them with the respective list entries
5767        ax.set_xticks(np.arange(len(classes)))
5768        ax.set_xticklabels(classes)
5769        ax.set_yticks(np.arange(len(classes)))
5770        ax.set_yticklabels(classes)
5771        # Rotate the tick labels and set their alignment.
5772        plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
5773        # Loop over data dimensions and create text annotations.
5774        for i in range(len(classes)):
5775            for j in range(len(classes)):
5776                ax.text(
5777                    j,
5778                    i,
5779                    np.round(confusion_matrix[i, j], 2),
5780                    ha="center",
5781                    va="center",
5782                    color="w",
5783                )
5784        if metric is not None:
5785            ax.set_title(f"{metric} {episode_name}")
5786        else:
5787            ax.set_title(episode_name)
5788        fig.tight_layout()
5789        if save_path is None:
5790            plt.show()
5791        else:
5792            plt.savefig(save_path)
5793            plt.close()
5794        return confusion_matrix, classes
5795
5796    def _plot_ethograms_gt_pred(
5797        self,
5798        data_gt: dict,
5799        data_pred: dict,
5800        labels_gt: list,
5801        labels_pred: list,
5802        start: int = 0,
5803        end: int = -1,
5804        cmap_pred: str = "binary",
5805        cmap_gt: str = "binary",
5806        save: str = None,
5807        fontsize=22,
5808        time_mode="frames",
5809        fps: int = None,
5810    ) -> None:
5811        """Plot ethograms from start to end time (in frames), mode can be prediction or ground truth depending on the data format."""
5812        # print(data.keys())
5813        best_pred = (
5814            data_pred[list(data_pred.keys())[0]].numpy() > 0.5
5815        )  # Threshold the predictions
5816        data_gt = binarize_data(data_gt, max_frame=end)
5817
5818        # Crop data to min length
5819        if end < 0:
5820            end = min(data_gt.shape[1], best_pred.shape[1])
5821        data_gt = data_gt[:, :end]
5822        best_pred = best_pred[:, :end]
5823
5824        # Reorder behaviors
5825        ind_gt = []
5826        ind_pred = []
5827        labels_pred = [labels_pred[i] for i in range(len(labels_pred))]
5828        labels_pred = np.roll(
5829            labels_pred, 1
5830        ).tolist()  
5831        check_gt = np.where(np.sum(data_gt, axis=1) > 0)[0]
5832        check_pred = np.where(np.sum(best_pred, axis=1) > 0)[0]
5833        for k, gt_beh in enumerate(labels_gt):
5834            if gt_beh in labels_pred:
5835                j = labels_pred.index(gt_beh)
5836                if not k in check_gt and not j in check_pred:
5837                    continue
5838                ind_gt.append(labels_gt.index(gt_beh))
5839                ind_pred.append(j)
5840        # Create label list
5841        labels = np.array(labels_gt)[ind_gt]
5842        assert (labels == np.array(labels_pred)[ind_pred]).all()
5843
5844        # # Create image
5845        image_pred = best_pred[ind_pred].astype(float)
5846        image_gt = data_gt[ind_gt]
5847
5848        f, axs = plt.subplots(
5849            len(labels), 1, figsize=(5 * len(labels), 15), sharex=True
5850        )
5851        end = image_gt.shape[1] if end < 0 else end
5852        for i, (ax, label) in enumerate(zip(axs, labels)):
5853
5854            im1 = np.array([image_gt[i], np.ones_like(image_gt[i]) * (-1)])
5855            im1 = np.ma.masked_array(im1, im1 < 0)
5856
5857            im2 = np.array([np.ones_like(image_pred[i]) * (-1), image_pred[i]])
5858            im2 = np.ma.masked_array(im2, im2 < 0)
5859
5860            ax.imshow(im1, aspect="auto", cmap=cmap_gt, interpolation="nearest")
5861            ax.imshow(im2, aspect="auto", cmap=cmap_pred, interpolation="nearest")
5862
5863            ax.set_yticks(np.arange(2), ["GT", "Pred"], fontsize=fontsize)
5864            ax.tick_params(axis="x", labelsize=fontsize)
5865            ax.set_ylabel(label, fontsize=fontsize)
5866            if time_mode == "frames":
5867                ax.set_xlabel("Num Frames", fontsize=fontsize)
5868            elif time_mode == "seconds":
5869                assert not fps is None, "Please provide fps"
5870                ax.set_xlabel("Time (s)", fontsize=fontsize)
5871                ax.set_xticks(
5872                    np.linspace(0, end, 10),
5873                    np.linspace(0, end / fps, 10).astype(np.int32),
5874                )
5875
5876            ax.set_xlim(start, end)
5877
5878        if save is None:
5879            plt.show()
5880        else:
5881            plt.savefig(save)
5882            plt.close()
5883
5884    def plot_ethograms(
5885        self,
5886        episode_name: str,
5887        prediction_name: str,
5888        start: int = 0,
5889        end: int = -1,
5890        save_path: str = None,
5891        cmap_pred: str = "binary",
5892        cmap_gt: str = "binary",
5893        fontsize: int = 22,
5894        time_mode: str = "frames",
5895        fps: int = None,
5896    ):
5897        """Plot ethograms from start to end time (in frames) for ground truth and prediction"""
5898        params = self._read_parameters(catch_blanks=False)
5899        parameters = self._get_data_pars(
5900            params,
5901        )
5902        if not save_path is None:
5903            os.makedirs(save_path, exist_ok=True)
5904        gt_files = [
5905            f for f in self.data_path if f.endswith(parameters["annotation_suffix"])
5906        ]
5907        pred_path = os.path.join(
5908            self.project_path, "results", "predictions", prediction_name
5909        )
5910        pred_paths = [os.path.join(pred_path, f) for f in os.listdir(pred_path)]
5911        for pred_path in pred_paths:
5912            predictions = load_pickle(pred_path)
5913            behaviors = self.get_behavior_dictionary(episode_name)
5914            gt_filename = os.path.basename(pred_path).replace(
5915                "_".join(["_" + prediction_name, "prediction.pickle"]),
5916                parameters["annotation_suffix"],
5917            )
5918            if os.path.exists(os.path.join(self.data_path, gt_filename)):
5919                gt_data = load_pickle(os.path.join(self.data_path, gt_filename))
5920
5921                self._plot_ethograms_gt_pred(
5922                    gt_data,
5923                    predictions,
5924                    gt_data[1],
5925                    behaviors,
5926                    start=start,
5927                    end=end,
5928                    save=os.path.join(
5929                        save_path,
5930                        os.path.splitext(os.path.basename(pred_path))[0] + "_gt_pred",
5931                    ),
5932                    cmap_pred=cmap_pred,
5933                    cmap_gt=cmap_gt,
5934                    fontsize=fontsize,
5935                    time_mode=time_mode,
5936                    fps=fps,
5937                )
5938            else:
5939                print("GT file not found")
5940
5941    def _create_side_panel(self, height, width, labels_pred, preds, labels_gt, gt=None):
5942        """Create a side panel for video annotation display.
5943
5944        Parameters
5945        ----------
5946        height : int
5947            the height of the panel
5948        width : int
5949            the width of the panel
5950        labels_pred : list
5951            the list of predicted behavior labels
5952        preds : array-like
5953            the prediction values for each behavior
5954        labels_gt : list
5955            the list of ground truth behavior labels
5956        gt : array-like, optional
5957            the ground truth values for each behavior
5958
5959        Returns
5960        -------
5961        side_panel : np.ndarray
5962            the created side panel as an image array
5963
5964        """
5965        side_panel = np.ones((height, int(width / 4), 3), dtype=np.uint8) * 255
5966
5967        beh_indices = np.where(preds)[0]
5968        for i, label in enumerate(labels_pred):
5969            color = (0, 0, 0)
5970            if i in beh_indices:
5971                color = (0, 255, 0)
5972            cv2.putText(
5973                side_panel,
5974                label,
5975                (10, 50 + 50 * i),
5976                cv2.FONT_HERSHEY_SIMPLEX,
5977                1,
5978                color,
5979                2,
5980                cv2.LINE_AA,
5981            )
5982        if gt is not None:
5983            beh_indices_gt = np.where(gt)[0]
5984            for i, label in enumerate(labels_gt):
5985                color = (0, 0, 0)
5986                if i in beh_indices_gt:
5987                    color = (0, 255, 0)
5988                cv2.putText(
5989                    side_panel,
5990                    label,
5991                    (10, 50 + 50 * i + 80 * len(labels_pred)),
5992                    cv2.FONT_HERSHEY_SIMPLEX,
5993                    1,
5994                    color,
5995                    2,
5996                    cv2.LINE_AA,
5997                )
5998        return side_panel
5999
6000    def create_annotated_video(
6001        self,
6002        prediction_file_paths: list,
6003        video_file_paths: list,
6004        episode_name: str,  # To get the list of behaviors
6005        ground_truth_file_paths: list = None,
6006        pred_thresh: float = 0.5,
6007        start: int = 0,
6008        end: int = -1,
6009    ):
6010        """Create a video with the predictions overlaid on the video"""
6011        for k, (pred_path, vid_path) in enumerate(
6012            zip(prediction_file_paths, video_file_paths)
6013        ):
6014            print("Generating video for :", os.path.basename(vid_path))
6015            predictions = load_pickle(pred_path)
6016            best_pred = predictions[list(predictions.keys())[0]].numpy() > pred_thresh
6017            behaviors = self.get_behavior_dictionary(episode_name)
6018            # Load video
6019            labels_pred = [behaviors[i] for i in range(len(behaviors))]
6020            labels_pred = np.roll(
6021                labels_pred, 1
6022            ).tolist() 
6023
6024            gt_data = None
6025            if ground_truth_file_paths is not None:
6026                gt_data = load_pickle(ground_truth_file_paths[k])
6027                labels_gt = gt_data[1]
6028                gt_data = binarize_data(gt_data, max_frame=best_pred.shape[1])
6029
6030            cap = cv2.VideoCapture(vid_path)
6031            cap.set(cv2.CAP_PROP_POS_FRAMES, start)
6032            end = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) if end < 0 else end
6033            fps = cap.get(cv2.CAP_PROP_FPS)
6034            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
6035            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
6036            fourcc = cv2.VideoWriter_fourcc(*"mp4v")
6037            out = cv2.VideoWriter(
6038                os.path.join(
6039                    os.path.dirname(vid_path),
6040                    os.path.splitext(os.path.basename(vid_path))[0] + "_annotated.mp4",
6041                ),
6042                fourcc,
6043                fps,
6044                # (width + int(width/4) , height),
6045                (600, 300),
6046            )
6047            count = 0
6048            bar = tqdm(total=end - start)
6049            while cap.isOpened():
6050                ret, frame = cap.read()
6051                if not ret:
6052                    break
6053
6054                side_panel = self._create_side_panel(
6055                    height,
6056                    width,
6057                    labels_pred,
6058                    best_pred[:, count],
6059                    labels_gt,
6060                    gt_data[:, count],
6061                )
6062                frame = np.concatenate((frame, side_panel), axis=1)
6063                frame = cv2.resize(frame, (0, 0), fx=0.25, fy=0.25)
6064                out.write(frame)
6065                count += 1
6066                bar.update(1)
6067
6068                if count > end:
6069                    break
6070
6071            cap.release()
6072            out.release()
6073            cv2.destroyAllWindows()
6074
6075    def plot_predictions(
6076        self,
6077        episode_name: str,
6078        load_epoch: int = None,
6079        parameters_update: Dict = None,
6080        add_legend: bool = True,
6081        ground_truth: bool = True,
6082        colormap: str = "dlc2action",
6083        hide_axes: bool = False,
6084        min_classes: int = 1,
6085        width: float = 10,
6086        whole_video: bool = False,
6087        transparent: bool = False,
6088        drop_classes: Set = None,
6089        search_classes: Set = None,
6090        num_plots: int = 1,
6091        remove_saved_features: bool = False,
6092        smooth_interval_prediction: int = 0,
6093        data_path: str = None,
6094        file_paths: Set = None,
6095        mode: str = "val",
6096        font_size: float = None,
6097        window_size: int = 400,
6098    ) -> None:
6099        """Visualize random predictions.
6100
6101        Parameters
6102        ----------
6103        episode_name : str
6104            the name of the episode to load
6105        load_epoch : int, optional
6106            the epoch to load (by default last)
6107        parameters_update : dict, optional
6108            parameter update dictionary
6109        add_legend : bool, default True
6110            if True, legend will be added to the plot
6111        ground_truth : bool, default True
6112            if True, ground truth will be added to the plot
6113        colormap : str, default 'Accent'
6114            the `matplotlib` colormap to use
6115        hide_axes : bool, default True
6116            if `True`, the axes will be hidden on the plot
6117        min_classes : int, default 1
6118            the minimum number of classes in a displayed interval
6119        width : float, default 10
6120            the width of the plot
6121        whole_video : bool, default False
6122            if `True`, whole videos are plotted instead of segments
6123        transparent : bool, default False
6124            if `True`, the background on the plot is transparent
6125        drop_classes : set, optional
6126            a set of class names to not be displayed
6127        search_classes : set, optional
6128            if given, only intervals where at least one of the classes is in ground truth will be shown
6129        num_plots : int, default 1
6130            the number of plots to make
6131        remove_saved_features : bool, default False
6132            if `True`, the dataset will be deleted after computation
6133        smooth_interval_prediction : int, default 0
6134            if >0, predictions shorter than this number of frames are removed (filled with prediction for the previous frame)
6135        data_path : str, optional
6136            the data path to run the prediction for
6137        file_paths : set, optional
6138            a set of string file paths (data with all prefixes + feature files, in any order) to run the prediction
6139            for
6140        mode : {'all', 'test', 'val', 'train'}
6141            the subset of the data to make the prediction for (forced to 'all' if data_path is not None)
6142
6143        """
6144        plot_path = os.path.join(self.project_path, "results", "plots")
6145        task, parameters, mode = self._make_task_prediction(
6146            "_",
6147            load_episode=episode_name,
6148            parameters_update=parameters_update,
6149            load_epoch=load_epoch,
6150            data_path=data_path,
6151            file_paths=file_paths,
6152            mode=mode,
6153        )
6154        os.makedirs(plot_path, exist_ok=True)
6155        task.visualize_results(
6156            save_path=os.path.join(plot_path, f"{episode_name}_prediction.svg"),
6157            add_legend=add_legend,
6158            ground_truth=ground_truth,
6159            colormap=colormap,
6160            hide_axes=hide_axes,
6161            min_classes=min_classes,
6162            whole_video=whole_video,
6163            transparent=transparent,
6164            dataset=mode,
6165            drop_classes=drop_classes,
6166            search_classes=search_classes,
6167            width=width,
6168            smooth_interval_prediction=smooth_interval_prediction,
6169            font_size=font_size,
6170            num_plots=num_plots,
6171            window_size=window_size,
6172        )
6173        if remove_saved_features:
6174            self._remove_stores(parameters)
6175
6176    def create_video_from_labels(
6177        self,
6178        video_dir_path: str,
6179        mode="ground_truth",
6180        prediction_name: str = None,
6181        save_path: str = None,
6182    ):
6183        if save_path is None:
6184            save_path = os.path.join(
6185                self.project_path, "results", f"annotated_videos_from_{mode}"
6186            )
6187        os.makedirs(save_path, exist_ok=True)
6188
6189        params = self._read_parameters(catch_blanks=False)
6190
6191        if mode == "ground_truth":
6192            source_dir = self.annotation_path
6193            annotation_suffix = params["data"]["annotation_suffix"]
6194        elif mode == "prediction":
6195            assert (
6196                not prediction_name is None
6197            ), "Please provide a prediction name with mode 'prediction'"
6198            source_dir = os.path.join(
6199                self.project_path, "results", "predictions", prediction_name
6200            )
6201            annotation_suffix = f"_{prediction_name}_prediction.pickle"
6202
6203        video_annotation_pairs = [
6204            (
6205                os.path.join(video_dir_path, f),
6206                os.path.join(
6207                    source_dir, f.replace(f.split(".")[-1], annotation_suffix)
6208                ),
6209            )
6210            for f in os.listdir(video_dir_path)
6211            if os.path.exists(
6212                os.path.join(source_dir, f.replace(f.split(".")[-1], annotation_suffix))
6213            )
6214        ]
6215
6216        for video_file, annotation_file in tqdm(video_annotation_pairs):
6217            if not os.path.exists(video_file):
6218                print(f"Video file {video_file} does not exist, skipping.")
6219                continue
6220            if not os.path.exists(annotation_file):
6221                print(f"Annotation file {annotation_file} does not exist, skipping.")
6222                continue
6223
6224            if annotation_file.endswith(".pickle"):
6225                annotations = load_pickle(annotation_file)
6226            elif annotation_file.endswith(".csv"):
6227                annotations = pd.read_csv(annotation_file)
6228
6229            if mode == "ground_truth":
6230                behaviors = annotations[1]
6231                annot_data = annotations[3]
6232            elif mode == "predictions":
6233                behaviors = list(annotations["classes"].values())
6234                annot_data = [
6235                    annotations[key]
6236                    for key in annotations.keys()
6237                    if key not in ["classes", "min_frame", "max_frame"]
6238                ]
6239                if params["general"]["exclusive"]:
6240                    annot_data = [np.argmax(annot, axis=1) for annot in annot_data]
6241                    seqs = [
6242                        [
6243                            self._bin_array_to_sequences(annot, target_value=k)
6244                            for k in range(len(behaviors))
6245                        ]
6246                        for annot in annot_data
6247                    ]
6248                else:
6249                    annot_data = [np.where(annot > 0.5)[0] for annot in annot_data]
6250                    seqs = [
6251                        self._bin_array_to_sequences(annot, target_value=1)
6252                        for annot in annot_data
6253                    ]
6254                annotations = ["", "", seqs]
6255
6256            for individual in annotations[3]:
6257                for behavior in annotations[3][individual]:
6258                    intervals = annotations[3][individual][behavior]
6259                    self._extract_videos(
6260                        video_file,
6261                        intervals,
6262                        behavior,
6263                        individual,
6264                        save_path,
6265                        resolution=(640, 480),
6266                        fps=30,
6267                    )
6268
6269    def _bin_array_to_sequences(
6270        self, annot_data: List[np.ndarray], target_value: int
6271    ) -> List[List[Tuple[int, int]]]:
6272        is_target = annot_data == target_value
6273        changes = np.diff(np.concatenate(([False], is_target, [False])))
6274        indices = np.where(changes)[0].reshape(-1, 2)
6275        subsequences = [list(range(start, end)) for start, end in indices]
6276        return subsequences
6277
6278    def _extract_videos(
6279        self,
6280        video_file: str,
6281        intervals: np.ndarray,
6282        behavior: str,
6283        individual: str,
6284        video_dir: str,
6285        resolution: Tuple[int, int] = (640, 480),
6286        fps: int = 30,
6287    ) -> None:
6288        """Extract frames from a video file from frames in between intervals in behavior folder for a given individual"""
6289        cap = cv2.VideoCapture(video_file)
6290        print("Extracting frames from", video_file)
6291
6292        for start, end, confusion in tqdm(intervals):
6293
6294            frame_count = start
6295            assert start < end, "Start frame should be less than end frame"
6296            if confusion > 0.5:
6297                continue
6298            cap.set(cv2.CAP_PROP_POS_FRAMES, start)
6299            output_file = os.path.join(
6300                video_dir,
6301                individual,
6302                behavior,
6303                os.path.splitext(os.path.basename(video_file))[0]
6304                + f"vid_{individual}_{behavior}_{start:05d}_{end:05d}.mp4",
6305            )
6306            fourcc = cv2.VideoWriter_fourcc(*"mp4v")  # Codec, e.g., 'XVID', 'MJPG'
6307            out = cv2.VideoWriter(
6308                output_file, fourcc, fps, (resolution[0], resolution[1])
6309            )
6310            while cap.isOpened():
6311                ret, frame = cap.read()
6312                if not ret:
6313                    break
6314
6315                # Resize large frames
6316                frame = cv2.resize(frame, (640, 480))
6317                out.write(frame)
6318
6319                frame_count += 1
6320                # Break if end frame is reached or max frames per behavior is reached
6321                if frame_count == end:
6322                    break
6323            if frame_count <= 2:
6324                os.remove(output_file)
6325            # cap.release()
6326            out.release()
6327
6328    def create_metadata_backup(self) -> None:
6329        """Create a copy of the meta files."""
6330        meta_copy_path = os.path.join(self.project_path, "meta", "backup")
6331        meta_path = os.path.join(self.project_path, "meta")
6332        if os.path.exists(meta_copy_path):
6333            shutil.rmtree(meta_copy_path)
6334        os.mkdir(meta_copy_path)
6335        for file in os.listdir(meta_path):
6336            if file == "backup":
6337                continue
6338            if os.path.isdir(os.path.join(meta_path, file)):
6339                continue
6340            shutil.copy(
6341                os.path.join(meta_path, file), os.path.join(meta_copy_path, file)
6342            )
6343
6344    def load_metadata_backup(self) -> None:
6345        """Load from previously created meta data backup (in case of corruption)."""
6346        meta_copy_path = os.path.join(self.project_path, "meta", "backup")
6347        meta_path = os.path.join(self.project_path, "meta")
6348        for file in os.listdir(meta_copy_path):
6349            shutil.copy(
6350                os.path.join(meta_copy_path, file), os.path.join(meta_path, file)
6351            )
6352
6353    def get_behavior_dictionary(self, episode_name: str) -> Dict:
6354        """Get the behavior dictionary for an episode.
6355
6356        Parameters
6357        ----------
6358        episode_name : str
6359            the name of the episode
6360
6361        Returns
6362        -------
6363        behaviors_dictionary : dict
6364            a dictionary where keys are label indices and values are label names
6365
6366        """
6367        return self._episode(episode_name).get_behaviors_dict()
6368
6369    def import_episodes(
6370        self,
6371        episodes_directory: str,
6372        name_map: Dict = None,
6373        repeat_policy: str = "error",
6374    ) -> None:
6375        """Import episodes exported with `Project.export_episodes`.
6376
6377        Parameters
6378        ----------
6379        episodes_directory : str
6380            the path to the exported episodes directory
6381        name_map : dict, optional
6382            a name change dictionary for the episodes: keys are old names, values are new names
6383        repeat_policy : {'error', 'skip', 'force'}, default 'error'
6384            the policy for repeated episode names: 'error' raises an error, 'skip' skips duplicates,
6385            'force' overwrites existing episodes
6386
6387        """
6388        if name_map is None:
6389            name_map = {}
6390        episodes = pd.read_pickle(os.path.join(episodes_directory, "episodes.pickle"))
6391        to_remove = []
6392        import_string = "Imported episodes: "
6393        for episode_name in episodes.index:
6394            if episode_name in name_map:
6395                import_string += f"{episode_name} "
6396                episode_name = name_map[episode_name]
6397                import_string += f"({episode_name}), "
6398            else:
6399                import_string += f"{episode_name}, "
6400            try:
6401                self._check_episode_validity(episode_name, allow_doublecolon=True)
6402            except ValueError as e:
6403                if str(e).endswith("is already taken!"):
6404                    if repeat_policy == "skip":
6405                        to_remove.append(episode_name)
6406                    elif repeat_policy == "force":
6407                        self.remove_episode(episode_name)
6408                    elif repeat_policy == "error":
6409                        raise ValueError(
6410                            f"The {episode_name} episode name is already taken; please use the name_map parameter to rename it"
6411                        )
6412                    else:
6413                        raise ValueError(
6414                            f"The {repeat_policy} repeat policy is not recognized; please choose from ['skip', 'force' and 'error']"
6415                        )
6416        episodes = episodes.drop(index=to_remove)
6417        self._episodes().update(
6418            episodes,
6419            name_map=name_map,
6420            force=(repeat_policy == "force"),
6421            data_path=self.data_path,
6422            annotation_path=self.annotation_path,
6423        )
6424        for episode_name in episodes.index:
6425            if episode_name in name_map:
6426                new_episode_name = name_map[episode_name]
6427            else:
6428                new_episode_name = episode_name
6429            model_dir = os.path.join(
6430                self.project_path, "results", "model", new_episode_name
6431            )
6432            old_model_dir = os.path.join(episodes_directory, "model", episode_name)
6433            if os.path.exists(model_dir):
6434                shutil.rmtree(model_dir)
6435            os.mkdir(model_dir)
6436            for file in os.listdir(old_model_dir):
6437                shutil.copyfile(
6438                    os.path.join(old_model_dir, file), os.path.join(model_dir, file)
6439                )
6440            log_file = os.path.join(
6441                self.project_path, "results", "logs", f"{new_episode_name}.txt"
6442            )
6443            old_log_file = os.path.join(
6444                episodes_directory, "logs", f"{episode_name}.txt"
6445            )
6446            shutil.copyfile(old_log_file, log_file)
6447        print(import_string)
6448        print("\n")
6449
6450    def export_episodes(
6451        self, episode_names: List, output_directory: str, name: str = None
6452    ) -> None:
6453        """Save selected episodes as a file that can be imported into another project with `Project.import_episodes`.
6454
6455        Parameters
6456        ----------
6457        episode_names : list
6458            a list of string episode names
6459        output_directory : str
6460            the path to the directory where the episodes will be saved
6461        name : str, optional
6462            the name of the episodes directory (by default `exported_episodes`)
6463
6464        """
6465        if name is None:
6466            name = "exported_episodes"
6467        if os.path.exists(
6468            os.path.join(output_directory, name + ".zip")
6469        ) or os.path.exists(os.path.join(output_directory, name)):
6470            i = 1
6471            while os.path.exists(
6472                os.path.join(output_directory, name + f"_{i}.zip")
6473            ) or os.path.exists(os.path.join(output_directory, name + f"_{i}")):
6474                i += 1
6475            name = name + f"_{i}"
6476        dest_dir = os.path.join(output_directory, name)
6477        os.mkdir(dest_dir)
6478        os.mkdir(os.path.join(dest_dir, "model"))
6479        os.mkdir(os.path.join(dest_dir, "logs"))
6480        runs = []
6481        for episode in episode_names:
6482            runs += self._episodes().get_runs(episode)
6483        for run in runs:
6484            shutil.copytree(
6485                os.path.join(self.project_path, "results", "model", run),
6486                os.path.join(dest_dir, "model", run),
6487            )
6488            shutil.copyfile(
6489                os.path.join(self.project_path, "results", "logs", f"{run}.txt"),
6490                os.path.join(dest_dir, "logs", f"{run}.txt"),
6491            )
6492        data = self._episodes().get_subset(runs)
6493        data.to_pickle(os.path.join(dest_dir, "episodes.pickle"))
6494
    def get_results_table(
        self,
        episode_names: List,
        metrics: List = None,
        mode: str = "mean",  # Choose between ["mean", "statistics", "detail"]
        print_results: bool = True,
        classes: List = None,
    ) -> pd.DataFrame:
        """Generate a `pandas` dataframe with a summary of episode results.

        Rows of the table are behaviors (plus an `"average"` row) and columns are
        episode/metric combinations.

        Parameters
        ----------
        episode_names : list
            a list of names of episodes to include
        metrics : list, optional
            a list of metric names to include (matched against the part of the results column
            name before the first underscore)
        mode : str, optional
            the mode of the results table, choose between ["mean", "statistics", "detail"], by default "mean":
            "mean" averages over the runs of an episode, "statistics" reports mean and standard deviation
            and "detail" reports every run separately; any other value adds nothing to the table
        print_results : bool, optional
            if True, the results will be printed to the console, by default True
        classes : list, optional
            a list of names of classes to include (by default all are included)

        Returns
        -------
        results : pd.DataFrame
            a table with the results

        """
        run_names = []
        for episode in episode_names:
            run_names += self._episodes().get_runs(episode)
        episodes = self.list_episodes(run_names, print_results=False)
        # Only the columns under the "results" top-level key hold metric values
        metric_columns = [x for x in episodes.columns if x[0] == "results"]
        results_df = pd.DataFrame()
        if metrics is not None:
            metric_columns = [
                x for x in metric_columns if x[1].split("_")[0] in metrics
            ]
        for episode in episode_names:
            results = []
            # Shared by all runs of the episode: metrics seen in any run so far
            metric_set = set()
            for run in self._episodes().get_runs(episode):
                beh_dict = self.get_behavior_dictionary(run)
                # behavior name -> {metric name -> value}
                res_dict = defaultdict(lambda: {})
                for column in metric_columns:
                    if np.isnan(episodes.loc[run, column]):
                        continue
                    split = column[1].split("_")
                    # A numeric suffix is a behavior index (per-class value);
                    # without it the column is treated as the averaged value
                    if split[-1].isnumeric():
                        beh_ind = int(split[-1])
                        metric_name = "_".join(split[:-1])
                        beh = beh_dict[beh_ind]
                    else:
                        beh = "average"
                        metric_name = column[1]
                    res_dict[beh][metric_name] = episodes.loc[run, column]
                    metric_set.add(metric_name)
                if "average" not in res_dict:
                    res_dict["average"] = {}
                # Fill in missing averages as the mean over the per-class values
                for metric in metric_set:
                    if metric not in res_dict["average"]:
                        arr = [
                            res_dict[beh][metric]
                            for beh in res_dict
                            if metric in res_dict[beh]
                        ]
                        res_dict["average"][metric] = np.mean(arr)
                results.append(res_dict)
            episode_results = {}
            for metric in metric_set:
                # The behaviors are taken from the first run only
                # NOTE(review): assumes every episode has at least one run — confirm
                for beh in results[0].keys():
                    if classes is not None and beh not in classes:
                        continue
                    # Values of this metric/behavior across the runs that have it
                    arr = []
                    for res_dict in results:
                        if metric in res_dict[beh]:
                            arr.append(res_dict[beh][metric])
                    if len(arr) > 0:
                        if mode == "statistics":
                            episode_results[(beh, f"{episode} {metric} mean")] = (
                                np.mean(arr)
                            )
                            episode_results[(beh, f"{episode} {metric} std")] = np.std(
                                arr
                            )
                        elif mode == "mean":
                            episode_results[(beh, f"{episode} {metric}")] = np.mean(arr)
                        elif mode == "detail":
                            for i, val in enumerate(arr):
                                episode_results[(beh, f"{episode}::{i} {metric}")] = val
            # Keys are (row label, column label) pairs
            for key, value in episode_results.items():
                results_df.loc[key[0], key[1]] = value
        if print_results:
            print(f"RESULTS:")
            print(results_df)
            print("\n")
        return results_df
6593
6594    def episode_exists(self, episode_name: str) -> bool:
6595        """Check if an episode already exists.
6596
6597        Parameters
6598        ----------
6599        episode_name : str
6600            the episode name
6601
6602        Returns
6603        -------
6604        exists : bool
6605            `True` if the episode exists
6606
6607        """
6608        return self._episodes().check_name_validity(episode_name)
6609
6610    def search_exists(self, search_name: str) -> bool:
6611        """Check if a search already exists.
6612
6613        Parameters
6614        ----------
6615        search_name : str
6616            the search name
6617
6618        Returns
6619        -------
6620        exists : bool
6621            `True` if the search exists
6622
6623        """
6624        return self._searches().check_name_validity(search_name)
6625
6626    def prediction_exists(self, prediction_name: str) -> bool:
6627        """Check if a prediction already exists.
6628
6629        Parameters
6630        ----------
6631        prediction_name : str
6632            the prediction name
6633
6634        Returns
6635        -------
6636        exists : bool
6637            `True` if the prediction exists
6638
6639        """
6640        return self._predictions().check_name_validity(prediction_name)
6641
6642    @staticmethod
6643    def project_name_available(projects_path: str, project_name: str):
6644        """Check if a project name is available.
6645
6646        Parameters
6647        ----------
6648        projects_path : str
6649            the path to the projects directory
6650        project_name : str
6651            the name of the project to check
6652
6653        Returns
6654        -------
6655        available : bool
6656            `True` if the project name is available
6657
6658        """
6659        if projects_path is None:
6660            projects_path = os.path.join(str(Path.home()), "DLC2Action")
6661        return not os.path.exists(os.path.join(projects_path, project_name))
6662
6663    def _update_episode_metrics(self, episode_name: str, metrics: Dict):
6664        """Update meta data with evaluation results.
6665
6666        Parameters
6667        ----------
6668        episode_name : str
6669            the name of the episode
6670        metrics : dict
6671            the metrics dictionary to update with
6672
6673        """
6674        self._episodes().update_episode_metrics(episode_name, metrics)
6675
6676    def rename_episode(self, episode_name: str, new_episode_name: str):
6677        """Rename an episode.
6678
6679        Parameters
6680        ----------
6681        episode_name : str
6682            the current episode name
6683        new_episode_name : str
6684            the new episode name
6685
6686        """
6687        shutil.move(
6688            os.path.join(self.project_path, "results", "model", episode_name),
6689            os.path.join(self.project_path, "results", "model", new_episode_name),
6690        )
6691        shutil.move(
6692            os.path.join(self.project_path, "results", "logs", f"{episode_name}.txt"),
6693            os.path.join(
6694                self.project_path, "results", "logs", f"{new_episode_name}.txt"
6695            ),
6696        )
6697        self._episodes().rename_episode(episode_name, new_episode_name)

A class to create and maintain the project files + keep track of experiments.

Project( name: str, data_type: str = None, annotation_type: str = 'none', projects_path: str = None, data_path: Union[str, List] = None, annotation_path: Union[str, List] = None, copy: bool = False)
 58    def __init__(
 59        self,
 60        name: str,
 61        data_type: str = None,
 62        annotation_type: str = "none",
 63        projects_path: str = None,
 64        data_path: Union[str, List] = None,
 65        annotation_path: Union[str, List] = None,
 66        copy: bool = False,
 67    ) -> None:
 68        """Initialize the class.
 69
 70        Parameters
 71        ----------
 72        name : str
 73            name of the project
 74        data_type : str, optional
 75            data type (run Project.data_types() to see available options; has to be provided if the project is being
 76            created)
 77        annotation_type : str, default 'none'
 78            annotation type (run Project.annotation_types() to see available options)
 79        projects_path : str, optional
 80            path to the projects folder (is filled with ~/DLC2Action by default)
 81        data_path : str, optional
 82            path to the folder containing input files for the project (has to be provided if the project is being
 83            created)
 84        annotation_path : str, optional
 85            path to the folder containing annotation files for the project
 86        copy : bool, default False
 87            if True, the files from annotation_path and data_path will be copied to the projects folder;
 88            otherwise they will be moved
 89
 90        """
 91        if projects_path is None:
 92            projects_path = os.path.join(str(Path.home()), "DLC2Action")
 93        if not os.path.exists(projects_path):
 94            os.mkdir(projects_path)
 95        self.project_path = os.path.join(projects_path, name)
 96        self.name = name
 97        self.data_type = data_type
 98        self.annotation_type = annotation_type
 99        self.data_path = data_path
100        self.annotation_path = annotation_path
101        if not os.path.exists(self.project_path):
102            if data_type is None:
103                raise ValueError(
104                    "The data_type parameter is necessary when creating a new project!"
105                )
106            self._initialize_project(
107                data_type, annotation_type, data_path, annotation_path, copy
108            )
109        else:
110            self.annotation_type, self.data_type = self._read_types()
111            if data_type != self.data_type and data_type is not None:
112                raise ValueError(
113                    f"The project has already been initialized with data_type={self.data_type}!"
114                )
115            if annotation_type != self.annotation_type and annotation_type != "none":
116                raise ValueError(
117                    f"The project has already been initialized with annotation_type={self.annotation_type}!"
118                )
119            self.annotation_path, data_path = self._read_paths()
120            if self.data_path is None:
121                self.data_path = data_path
122            # if data_path != self.data_path and data_path is not None:
123            #     raise ValueError(
124            #         f"The project has already been initialized with data_path={self.data_path}!"
125            #     )
126            if annotation_path != self.annotation_path and annotation_path is not None:
127                raise ValueError(
128                    f"The project has already been initialized with annotation_path={self.annotation_path}!"
129                )
130        self._update_configs()

Initialize the class.

Parameters

name : str name of the project data_type : str, optional data type (run Project.data_types() to see available options; has to be provided if the project is being created) annotation_type : str, default 'none' annotation type (run Project.annotation_types() to see available options) projects_path : str, optional path to the projects folder (is filled with ~/DLC2Action by default) data_path : str, optional path to the folder containing input files for the project (has to be provided if the project is being created) annotation_path : str, optional path to the folder containing annotation files for the project copy : bool, default False if True, the files from annotation_path and data_path will be copied to the projects folder; otherwise they will be moved

project_path
name
data_type
annotation_type
data_path
annotation_path
def get_decision_thresholds( self, episode_names: List, metric_name: str = 'f1', parameters_update: Dict = None, load_epochs: List = None, remove_saved_features: bool = False) -> Tuple[List, List, dlc2action.task.task_dispatcher.TaskDispatcher]:
563    def get_decision_thresholds(
564        self,
565        episode_names: List,
566        metric_name: str = "f1",
567        parameters_update: Dict = None,
568        load_epochs: List = None,
569        remove_saved_features: bool = False,
570    ) -> Tuple[List, List, TaskDispatcher]:
571        """Compute optimal decision thresholds or load them if they have been computed before.
572
573        Parameters
574        ----------
575        episode_names : List
576            a list of episode names
577        metric_name : {"f1", "segmental_f1", "semisegmental_f1", "f_beta", "segmental_f_beta"}
578            the metric to optimize
579        parameters_update : dict, optional
580            the parameter update dictionary
581        load_epochs : list, optional
582            a list of epochs to load (by default last are loaded)
583        remove_saved_features : bool, default False
584            if `True`, the dataset will be deleted after the computation
585
586        Returns
587        -------
588        thresholds : list
589            a list of float decision threshold values
590        classes : list
591            the label names corresponding to the values
592        task : TaskDispatcher | None
593            the task used in computation
594
595        """
596        parameters = self._make_parameters(
597            "_",
598            episode_names[0],
599            parameters_update,
600            {},
601            load_epochs[0],
602            purpose="prediction",
603        )
604        thresholds = self._thresholds().find_thresholds(
605            episode_names,
606            load_epochs,
607            metric_name,
608            metric_parameters=parameters["metrics"][metric_name],
609        )
610        task = None
611        behaviors = list(self._episode(episode_names[0]).get_behaviors_dict().values())
612        return thresholds, behaviors, task

Compute optimal decision thresholds or load them if they have been computed before.

Parameters

episode_names : List a list of episode names metric_name : {"f1", "segmental_f1", "semisegmental_f1", "f_beta", "segmental_f_beta"} the metric to optimize parameters_update : dict, optional the parameter update dictionary load_epochs : list, optional a list of epochs to load (by default last are loaded) remove_saved_features : bool, default False if True, the dataset will be deleted after the computation

Returns

thresholds : list a list of float decision threshold values classes : list the label names corresponding to the values task : TaskDispatcher | None the task used in computation

def run_episode( self, episode_name: str, load_episode: str = None, parameters_update: Dict = None, task: dlc2action.task.task_dispatcher.TaskDispatcher = None, load_epoch: int = None, load_search: str = None, load_parameters: list = None, round_to_binary: list = None, load_strict: bool = True, n_seeds: int = 1, force: bool = False, suppress_name_check: bool = False, remove_saved_features: bool = False, mask_name: str = None, autostop_metric: str = None, autostop_interval: int = 50, autostop_threshold: float = 0.001, loading_bar: bool = False, trial: Tuple = None) -> dlc2action.task.task_dispatcher.TaskDispatcher:
614    def run_episode(
615        self,
616        episode_name: str,
617        load_episode: str = None,
618        parameters_update: Dict = None,
619        task: TaskDispatcher = None,
620        load_epoch: int = None,
621        load_search: str = None,
622        load_parameters: list = None,
623        round_to_binary: list = None,
624        load_strict: bool = True,
625        n_seeds: int = 1,
626        force: bool = False,
627        suppress_name_check: bool = False,
628        remove_saved_features: bool = False,
629        mask_name: str = None,
630        autostop_metric: str = None,
631        autostop_interval: int = 50,
632        autostop_threshold: float = 0.001,
633        loading_bar: bool = False,
634        trial: Tuple = None,
635    ) -> TaskDispatcher:
636        """Run an episode.
637
638        The task parameters are read from the config files and then updated with the
639        parameters_update dictionary. The model can be either initialized from scratch or loaded from one of the
640        previous experiments. All parameters and results are saved in the meta files and can be accessed with the
641        list_episodes() function. The train/test/validation split is saved and loaded from a file whenever the
642        same split parameters are used. The pre-computed datasets are also saved and loaded whenever the same
643        data parameters are used.
644
645        You can use the autostop parameters to finish training when the parameters are not improving. It will be
646        stopped if the average value of `autostop_metric` over the last `autostop_interval` epochs is smaller than
647        the average over the previous `autostop_interval` epochs + `autostop_threshold`. For example, if the
648        current epoch is 120 and `autostop_interval` is 50, the averages over epochs 70-120 and 20-70 will be compared.
649
650        Parameters
651        ----------
652        episode_name : str
653            the episode name
654        load_episode : str, optional
655            the (previously run) episode name to load the model from; if the episode has multiple runs,
656            the new episode will have the same number of runs, each starting with one of the pre-trained models
657        parameters_update : dict, optional
658            the dictionary used to update the parameters from the config files
659        task : TaskDispatcher, optional
660            a pre-existing `TaskDispatcher` object (if provided, the method will update it instead of creating
661            a new instance)
662        load_epoch : int, optional
663            the epoch to load (if load_episodes is not None); if not provided, the last epoch is used
664        load_search : str, optional
665            the hyperparameter search result to load
666        load_parameters : list, optional
667            a list of string names of the parameters to load from load_search (if not provided, all parameters
668            are loaded)
669        round_to_binary : list, optional
670            a list of string names of the loaded parameters that should be rounded to the nearest power of two
671        load_strict : bool, default True
672            if `False`, matching weights will be loaded from `load_episode` and differences in parameter name lists and
673            weight shapes will be ignored; otherwise mismatches will prompt a `RuntimeError`
674        n_seeds : int, default 1
675            the number of runs to perform; if `n_seeds > 1`, the episodes will be named `episode_name#run_index`, e.g.
676            `test_episode#0` and `test_episode#1`
677        force : bool, default False
678            if `True` and an episode with name `episode_name` already exists, it will be overwritten (use with caution!)
679        suppress_name_check : bool, default False
680            if `True`, episode names with a double colon are allowed (please don't use this option unless you understand
681            why they are usually forbidden)
682        remove_saved_features : bool, default False
683            if `True`, the dataset will be deleted after training
684        mask_name : str, optional
685            the name of the real_lens to apply
686        autostop_metric : str, optional
687            the autostop metric (can be any one of the tracked metrics of `'loss'`)
688        autostop_interval : int, default 50
689            the number of epochs to average the autostop metric over
690        autostop_threshold : float, default 0.001
691            the autostop difference threshold
692        loading_bar : bool, default False
693            if `True`, a loading bar will be displayed
694        trial : tuple, optional
695            a tuple of (trial, metric) for hyperparameter search
696
697        Returns
698        -------
699        TaskDispatcher
700            the `TaskDispatcher` object
701
702        """
703
704        import gc
705
706        gc.collect()
707        if torch.cuda.is_available():
708            torch.cuda.empty_cache()
709
710        if type(n_seeds) is not int or n_seeds < 1:
711            raise ValueError(
712                f"The n_seeds parameter has to be an integer larger than 0; got {n_seeds}"
713            )
714        if n_seeds > 1 and mask_name is not None:
715            raise ValueError("Cannot apply a real_lens with n_seeds > 1")
716        self._check_episode_validity(
717            episode_name, allow_doublecolon=suppress_name_check, force=force
718        )
719        load_runs = self._episodes().get_runs(load_episode)
720        if len(load_runs) > 1:
721            task = self.run_episodes(
722                episode_names=[
723                    f'{episode_name}#{run.split("#")[-1]}' for run in load_runs
724                ],
725                load_episodes=load_runs,
726                parameters_updates=[parameters_update for _ in load_runs],
727                load_epochs=[load_epoch for _ in load_runs],
728                load_searches=[load_search for _ in load_runs],
729                load_parameters=[load_parameters for _ in load_runs],
730                round_to_binary=[round_to_binary for _ in load_runs],
731                load_strict=[load_strict for _ in load_runs],
732                suppress_name_check=True,
733                force=force,
734                remove_saved_features=False,
735            )
736            if remove_saved_features:
737                self._remove_stores(
738                    {
739                        "general": task.general_parameters,
740                        "data": task.data_parameters,
741                        "features": task.feature_parameters,
742                    }
743                )
744            if n_seeds > 1:
745                warnings.warn(
746                    f"The n_seeds parameter is disregarded since load_episode={load_episode} has multiple runs"
747                )
748        elif n_seeds > 1:
749
750            self.run_episodes(
751                episode_names=[f"{episode_name}#{i}" for i in range(n_seeds)],
752                load_episodes=[load_episode for _ in range(n_seeds)],
753                parameters_updates=[parameters_update for _ in range(n_seeds)],
754                load_epochs=[load_epoch for _ in range(n_seeds)],
755                load_searches=[load_search for _ in range(n_seeds)],
756                load_parameters=[load_parameters for _ in range(n_seeds)],
757                round_to_binary=[round_to_binary for _ in range(n_seeds)],
758                load_strict=[load_strict for _ in range(n_seeds)],
759                suppress_name_check=True,
760                force=force,
761                remove_saved_features=remove_saved_features,
762            )
763        else:
764            print(f"TRAINING {episode_name}")
765            try:
766                task, parameters = self._make_task_training(
767                    episode_name,
768                    load_episode,
769                    parameters_update,
770                    load_epoch,
771                    load_search,
772                    load_parameters,
773                    round_to_binary,
774                    continuing=False,
775                    task=task,
776                    mask_name=mask_name,
777                    load_strict=load_strict,
778                )
779                self._save_episode(
780                    episode_name,
781                    parameters,
782                    task.behaviors_dict(),
783                    norm_stats=task.get_normalization_stats(),
784                )
785                time_start = time.time()
786                if trial is not None:
787                    trial, metric = trial
788                else:
789                    trial, metric = None, None
790                logs = task.train(
791                    autostop_metric=autostop_metric,
792                    autostop_interval=autostop_interval,
793                    autostop_threshold=autostop_threshold,
794                    loading_bar=loading_bar,
795                    trial=trial,
796                    optimized_metric=metric,
797                )
798                time_end = time.time()
799                time_total = time_end - time_start
800                hours = int(time_total // 3600)
801                time_total -= hours * 3600
802                minutes = int(time_total // 60)
803                time_total -= minutes * 60
804                seconds = int(time_total)
805                training_time = f"{hours}:{minutes:02}:{seconds:02}"
806                self._update_episode_results(episode_name, logs, training_time)
807                if remove_saved_features:
808                    self._remove_stores(parameters)
809                print("\n")
810                return task
811
812            except Exception as e:
813                if isinstance(e, optuna.exceptions.TrialPruned):
814                    raise e
815                else:
816                    # if str(e) != f"The {episode_name} episode name is already in use!":
817                    #     self.remove_episode(episode_name)
818                    raise RuntimeError(f"Episode {episode_name} could not run")

Run an episode.

The task parameters are read from the config files and then updated with the parameters_update dictionary. The model can be either initialized from scratch or loaded from one of the previous experiments. All parameters and results are saved in the meta files and can be accessed with the list_episodes() function. The train/test/validation split is saved and loaded from a file whenever the same split parameters are used. The pre-computed datasets are also saved and loaded whenever the same data parameters are used.

You can use the autostop parameters to finish training when the parameters are not improving. It will be stopped if the average value of autostop_metric over the last autostop_interval epochs is smaller than the average over the previous autostop_interval epochs + autostop_threshold. For example, if the current epoch is 120 and autostop_interval is 50, the averages over epochs 70-120 and 20-70 will be compared.

Parameters

episode_name : str the episode name load_episode : str, optional the (previously run) episode name to load the model from; if the episode has multiple runs, the new episode will have the same number of runs, each starting with one of the pre-trained models parameters_update : dict, optional the dictionary used to update the parameters from the config files task : TaskDispatcher, optional a pre-existing TaskDispatcher object (if provided, the method will update it instead of creating a new instance) load_epoch : int, optional the epoch to load (if load_episodes is not None); if not provided, the last epoch is used load_search : str, optional the hyperparameter search result to load load_parameters : list, optional a list of string names of the parameters to load from load_search (if not provided, all parameters are loaded) round_to_binary : list, optional a list of string names of the loaded parameters that should be rounded to the nearest power of two load_strict : bool, default True if False, matching weights will be loaded from load_episode and differences in parameter name lists and weight shapes will be ignored; otherwise mismatches will prompt a RuntimeError n_seeds : int, default 1 the number of runs to perform; if n_seeds > 1, the episodes will be named episode_name#run_index, e.g. test_episode#0 and test_episode#1 force : bool, default False if True and an episode with name episode_name already exists, it will be overwritten (use with caution!) 
suppress_name_check : bool, default False if True, episode names with a double colon are allowed (please don't use this option unless you understand why they are usually forbidden) remove_saved_features : bool, default False if True, the dataset will be deleted after training mask_name : str, optional the name of the real_lens to apply autostop_metric : str, optional the autostop metric (can be any one of the tracked metrics of 'loss') autostop_interval : int, default 50 the number of epochs to average the autostop metric over autostop_threshold : float, default 0.001 the autostop difference threshold loading_bar : bool, default False if True, a loading bar will be displayed trial : tuple, optional a tuple of (trial, metric) for hyperparameter search

Returns

TaskDispatcher the TaskDispatcher object

def run_episodes( self, episode_names: List, load_episodes: List = None, parameters_updates: List = None, load_epochs: List = None, load_searches: List = None, load_parameters: List = None, round_to_binary: List = None, load_strict: List = None, force: bool = False, suppress_name_check: bool = False, remove_saved_features: bool = False) -> dlc2action.task.task_dispatcher.TaskDispatcher:
820    def run_episodes(
821        self,
822        episode_names: List,
823        load_episodes: List = None,
824        parameters_updates: List = None,
825        load_epochs: List = None,
826        load_searches: List = None,
827        load_parameters: List = None,
828        round_to_binary: List = None,
829        load_strict: List = None,
830        force: bool = False,
831        suppress_name_check: bool = False,
832        remove_saved_features: bool = False,
833    ) -> TaskDispatcher:
834        """Run multiple episodes in sequence (and re-use previously loaded information).
835
836        For each episode, the task parameters are read from the config files and then updated with the
837        parameter_update dictionary. The model can be either initialized from scratch or loaded from one of the
838        previous experiments. All parameters and results are saved in the meta files and can be accessed with the
839        list_episodes() function. The train/test/validation split is saved and loaded from a file whenever the
840        same split parameters are used. The pre-computed datasets are also saved and loaded whenever the same
841        data parameters are used.
842
843        Parameters
844        ----------
845        episode_names : list
846            a list of strings of episode names
847        load_episodes : list, optional
848            a list of strings of (previously run) episode names to load the model from; if the episode has multiple runs,
849            the new episode will have the same number of runs, each starting with one of the pre-trained models
850        parameters_updates : list, optional
851            a list of dictionaries used to update the parameters from the config
852        load_epochs : list, optional
853            a list of integers used to specify the epoch to load (if load_episodes is not None)
854        load_searches : list, optional
855            a list of strings of hyperparameter search results to load
856        load_parameters : list, optional
857            a list of lists of string names of the parameters to load from the searches
858        round_to_binary : list, optional
859            a list of string names of the loaded parameters that should be rounded to the nearest power of two
860        load_strict : list, optional
861            a list of boolean values specifying weight loading policy: if `False`, matching weights will be loaded from
862            the corresponding episode and differences in parameter name lists and
863            weight shapes will be ignored; otherwise mismatches will prompt a `RuntimeError` (by default `True` for
864            every episode)
865        force : bool, default False
866            if `True` and an episode name is already taken, it will be overwritten (use with caution!)
867        suppress_name_check : bool, default False
868            if `True`, episode names with a double colon are allowed (please don't use this option unless you understand
869            why they are usually forbidden)
870        remove_saved_features : bool, default False
871            if `True`, the dataset will be deleted after training
872
873        Returns
874        -------
875        TaskDispatcher
876            the task dispatcher object
877
878        """
879        task = None
880        if load_searches is None:
881            load_searches = [None for _ in episode_names]
882        if load_episodes is None:
883            load_episodes = [None for _ in episode_names]
884        if parameters_updates is None:
885            parameters_updates = [None for _ in episode_names]
886        if load_parameters is None:
887            load_parameters = [None for _ in episode_names]
888        if load_epochs is None:
889            load_epochs = [None for _ in episode_names]
890        if load_strict is None:
891            load_strict = [True for _ in episode_names]
892        for (
893            parameters_update,
894            episode_name,
895            load_episode,
896            load_epoch,
897            load_search,
898            load_parameters_list,
899            load_strict_value,
900        ) in zip(
901            parameters_updates,
902            episode_names,
903            load_episodes,
904            load_epochs,
905            load_searches,
906            load_parameters,
907            load_strict,
908        ):
909            task = self.run_episode(
910                episode_name,
911                load_episode,
912                parameters_update,
913                task,
914                load_epoch,
915                load_search,
916                load_parameters_list,
917                round_to_binary,
918                load_strict_value,
919                suppress_name_check=suppress_name_check,
920                force=force,
921                remove_saved_features=remove_saved_features,
922            )
923        return task

Run multiple episodes in sequence (and re-use previously loaded information).

For each episode, the task parameters are read from the config files and then updated with the corresponding dictionary from parameters_updates. The model can be either initialized from scratch or loaded from one of the previous experiments. All parameters and results are saved in the meta files and can be accessed with the list_episodes() function. The train/test/validation split is saved and loaded from a file whenever the same split parameters are used. The pre-computed datasets are also saved and loaded whenever the same data parameters are used.

Parameters

episode_names : list a list of strings of episode names load_episodes : list, optional a list of strings of (previously run) episode names to load the model from; if the episode has multiple runs, the new episode will have the same number of runs, each starting with one of the pre-trained models parameters_updates : list, optional a list of dictionaries used to update the parameters from the config load_epochs : list, optional a list of integers used to specify the epoch to load (if load_episodes is not None) load_searches : list, optional a list of strings of hyperparameter search results to load load_parameters : list, optional a list of lists of string names of the parameters to load from the searches round_to_binary : list, optional a list of string names of the loaded parameters that should be rounded to the nearest power of two load_strict : list, optional a list of boolean values specifying weight loading policy: if False, matching weights will be loaded from the corresponding episode and differences in parameter name lists and weight shapes will be ignored; otherwise mismatches will prompt a RuntimeError (by default True for every episode) force : bool, default False if True and an episode name is already taken, it will be overwritten (use with caution!) suppress_name_check : bool, default False if True, episode names with a double colon are allowed (please don't use this option unless you understand why they are usually forbidden) remove_saved_features : bool, default False if True, the dataset will be deleted after training

Returns

TaskDispatcher the task dispatcher object

def continue_episode( self, episode_name: str, num_epochs: int = None, task: dlc2action.task.task_dispatcher.TaskDispatcher = None, n_seeds: int = 1, remove_saved_features: bool = False, device: str = 'cuda', num_cpus: int = None) -> dlc2action.task.task_dispatcher.TaskDispatcher:
 925    def continue_episode(
 926        self,
 927        episode_name: str,
 928        num_epochs: int = None,
 929        task: TaskDispatcher = None,
 930        n_seeds: int = 1,
 931        remove_saved_features: bool = False,
 932        device: str = "cuda",
 933        num_cpus: int = None,
 934    ) -> TaskDispatcher:
 935        """Load an older episode and continue running from the latest checkpoint.
 936
 937        All parameters as well as the model and optimizer state dictionaries are loaded from the episode.
 938
 939        Parameters
 940        ----------
 941        episode_name : str
 942            the name of the episode to continue
 943        num_epochs : int, optional
 944            the new number of epochs
 945        task : TaskDispatcher, optional
 946            a pre-existing task; if provided, the method will update the task instead of creating a new one
 947            (this might save time, mainly on dataset loading)
 948        n_seeds : int, default 1
 949            the number of runs to perform; if `n_seeds > 1`, the episodes will be named `episode_name#run_index`, e.g.
 950            `test_episode#0` and `test_episode#1`
 951        remove_saved_features : bool, default False
 952            if `True`, pre-computed features will be deleted after the run
 953        device : str, default "cuda"
 954            the torch device to use
 955        num_cpus : int, optional
 956            the number of CPUs to use for data loading; if `None`, the number of available CPUs will be used
 957
 958        Returns
 959        -------
 960        TaskDispatcher
 961            the task dispatcher
 962
 963        """
 964        runs = self._episodes().get_runs(episode_name)
 965        for run in runs:
 966            print(f"TRAINING {run}")
 967            if num_epochs is None and not self._episode(run).unfinished():
 968                continue
 969            parameters_update = {
 970                "training": {
 971                    "num_epochs": num_epochs,
 972                    "device": device,
 973                },
 974                "general": {"num_cpus": num_cpus},
 975            }
 976            task, parameters = self._make_task_training(
 977                run,
 978                load_episode=run,
 979                parameters_update=parameters_update,
 980                continuing=True,
 981                task=task,
 982            )
 983            time_start = time.time()
 984            logs = task.train()
 985            time_end = time.time()
 986            old_time = self._training_time(run)
 987            if not np.isnan(old_time):
 988                time_end += old_time
 989                time_total = time_end - time_start
 990                hours = int(time_total // 3600)
 991                time_total -= hours * 3600
 992                minutes = int(time_total // 60)
 993                time_total -= minutes * 60
 994                seconds = int(time_total)
 995                training_time = f"{hours}:{minutes:02}:{seconds:02}"
 996            else:
 997                training_time = np.nan
 998            self._save_episode(
 999                run,
1000                parameters,
1001                task.behaviors_dict(),
1002                suppress_validation=True,
1003                training_time=training_time,
1004                norm_stats=task.get_normalization_stats(),
1005            )
1006            self._update_episode_results(run, logs)
1007            print("\n")
1008        if len(runs) < n_seeds:
1009            for i in range(len(runs), n_seeds):
1010                self.run_episode(
1011                    f"{episode_name}#{i}",
1012                    parameters_update=self._episodes().load_parameters(runs[0]),
1013                    task=task,
1014                    suppress_name_check=True,
1015                )
1016        if remove_saved_features:
1017            self._remove_stores(parameters)
1018        return task

Load an older episode and continue running from the latest checkpoint.

All parameters as well as the model and optimizer state dictionaries are loaded from the episode.

Parameters

episode_name : str the name of the episode to continue num_epochs : int, optional the new number of epochs task : TaskDispatcher, optional a pre-existing task; if provided, the method will update the task instead of creating a new one (this might save time, mainly on dataset loading) n_seeds : int, default 1 the number of runs to perform; if n_seeds > 1, the episodes will be named episode_name#run_index, e.g. test_episode#0 and test_episode#1 remove_saved_features : bool, default False if True, pre-computed features will be deleted after the run device : str, default "cuda" the torch device to use num_cpus : int, optional the number of CPUs to use for data loading; if None, the number of available CPUs will be used

Returns

TaskDispatcher the task dispatcher

def run_prediction( self, prediction_name: str, episode_names: List, load_epochs: List = None, parameters_update: Dict = None, augment_n: int = 10, data_path: str = None, mode: str = 'all', file_paths: Set = None, remove_saved_features: bool = False, frame_number_map_file: str = None, force: bool = False, embedding: bool = False) -> None:
1260    def run_prediction(
1261        self,
1262        prediction_name: str,
1263        episode_names: List,
1264        load_epochs: List = None,
1265        parameters_update: Dict = None,
1266        augment_n: int = 10,
1267        data_path: str = None,
1268        mode: str = "all",
1269        file_paths: Set = None,
1270        remove_saved_features: bool = False,
1271        frame_number_map_file: str = None,
1272        force: bool = False,
1273        embedding: bool = False,
1274    ) -> None:
1275        """Load models from previously run episodes to generate a prediction.
1276
1277        The probabilities predicted by the models are averaged.
1278        Unless `submission` is `True`, the prediction results are saved as a pickled dictionary in the project_name/results/predictions folder
1279        under the {episode_name}_{load_epoch}.pickle name. The file is a nested dictionary where the first-level
1280        keys are the video ids, the second-level keys are the clip ids (like individual names) and the values
1281        are the prediction arrays.
1282
1283        Parameters
1284        ----------
1285        prediction_name : str
1286            the name of the prediction
1287        episode_names : list
1288            a list of string episode names to load the models from
1289        load_epochs : list or int, optional
1290            a list of integer epoch indices to load the model from; if None, the last ones are used, if int the same epoch is used for all episodes
1291        parameters_update : dict, optional
1292            a dictionary of parameter updates
1293        augment_n : int, default 10
1294            the number of augmentations to average over
1295        data_path : str, optional
1296            the data path to run the prediction for
1297        mode : {'all', 'test', 'val', 'train'}
1298            the subset of the data to make the prediction for (forced to 'all' if data_path is not None)
1299        file_paths : set, optional
1300            a set of string file paths (data with all prefixes + feature files, in any order) to run the prediction
1301            for
1302        remove_saved_features : bool, default False
1303            if `True`, pre-computed features will be deleted
1304        submission : bool, default False
1305            if `True`, a MABe-22 style submission file is generated
1306        frame_number_map_file : str, optional
1307            path to the frame number map file
1308        force : bool, default False
1309            if `True`, existing prediction with this name will be overwritten
1310        embedding : bool, default False
1311            if `True`, the prediction is made for the embedding task
1312
1313        """
1314        self._check_prediction_validity(prediction_name, force=force)
1315        print(f"PREDICTION {prediction_name}")
1316        task, parameters, mode, prediction, inference_time, behavior_dict = (
1317            self._make_prediction(
1318                prediction_name,
1319                episode_names,
1320                load_epochs,
1321                parameters_update,
1322                data_path,
1323                file_paths,
1324                mode,
1325                augment_n,
1326                evaluate=False,
1327                embedding=embedding,
1328            )
1329        )
1330        predicted = task.dataset(mode).generate_full_length_prediction(prediction)
1331
1332        if remove_saved_features:
1333            self._remove_stores(parameters)
1334
1335        self._save_prediction(
1336            prediction_name,
1337            predicted,
1338            parameters,
1339            task,
1340            mode,
1341            embedding,
1342            inference_time,
1343            behavior_dict,
1344        )
1345        print("\n")

Load models from previously run episodes to generate a prediction.

The probabilities predicted by the models are averaged. Unless submission is True, the prediction results are saved as a pickled dictionary in the project_name/results/predictions folder under the {episode_name}_{load_epoch}.pickle name. The file is a nested dictionary where the first-level keys are the video ids, the second-level keys are the clip ids (like individual names) and the values are the prediction arrays.

Parameters

prediction_name : str the name of the prediction episode_names : list a list of string episode names to load the models from load_epochs : list or int, optional a list of integer epoch indices to load the model from; if None, the last ones are used, if int the same epoch is used for all episodes parameters_update : dict, optional a dictionary of parameter updates augment_n : int, default 10 the number of augmentations to average over data_path : str, optional the data path to run the prediction for mode : {'all', 'test', 'val', 'train'} the subset of the data to make the prediction for (forced to 'all' if data_path is not None) file_paths : set, optional a set of string file paths (data with all prefixes + feature files, in any order) to run the prediction for remove_saved_features : bool, default False if True, pre-computed features will be deleted submission : bool, default False if True, a MABe-22 style submission file is generated frame_number_map_file : str, optional path to the frame number map file force : bool, default False if True, existing prediction with this name will be overwritten embedding : bool, default False if True, the prediction is made for the embedding task

def evaluate_prediction( self, prediction_name: str, parameters_update: Dict = None, data_path: str = None, annotation_path: str = None, file_paths: Set = None, mode: str = None, remove_saved_features: bool = False, annotation_type: str = 'none', num_classes: int = None) -> Tuple[float, dict]:
1347    def evaluate_prediction(
1348        self,
1349        prediction_name: str,
1350        parameters_update: Dict = None,
1351        data_path: str = None,
1352        annotation_path: str = None,
1353        file_paths: Set = None,
1354        mode: str = None,
1355        remove_saved_features: bool = False,
1356        annotation_type: str = "none",
1357        num_classes: int = None,  # Set when using data_path
1358    ) -> Tuple[float, dict]:
1359        """Make predictions and evaluate them
1360        inputs:
1361            prediction_name (str): the name of the prediction
1362            parameters_update (dict): a dictionary of parameter updates
1363            data_path (str): the data path to run the prediction for
1364            annotation_path (str): the annotation path to run the prediction for
1365            file_paths (set): a set of string file paths (data with all prefixes + feature files, in any order) to run the prediction for
1366            mode (str): the subset of the data to make the prediction for (forced to 'all' if data_path is not None)
1367            remove_saved_features (bool): if `True`, pre-computed features will be deleted
1368            annotation_type (str): the type of annotation to use for evaluation
1369            num_classes (int): the number of classes in the dataset, must be set with data_path
1370        outputs:
1371            results (dict): a dictionary of average values of metric functions
1372        """
1373
1374        prediction_path = os.path.join(
1375            self.project_path, "results", "predictions", f"{prediction_name}"
1376        )
1377        prediction_dict = {}
1378        for prediction_file_path in [
1379            os.path.join(prediction_path, i) for i in os.listdir(prediction_path)
1380        ]:
1381            with open(os.path.join(prediction_file_path), "rb") as f:
1382                prediction = pickle.load(f)
1383            video_id = os.path.basename(prediction_file_path).split(
1384                "_" + prediction_name
1385            )[0]
1386            prediction_dict[video_id] = prediction
1387        if parameters_update is None:
1388            parameters_update = {}
1389        parameters_update = self._update(
1390            self._predictions().load_parameters(prediction_name), parameters_update
1391        )
1392        parameters_update.pop("model")
1393        if not data_path is None:
1394            assert (
1395                not num_classes is None
1396            ), "num_classes must be provided if data_path is provided"
1397            parameters_update["general"]["num_classes"] = num_classes + int(
1398                parameters_update["general"]["exclusive"]
1399            )
1400        task, parameters, mode = self._make_task_prediction(
1401            "_",
1402            load_episode=None,
1403            parameters_update=parameters_update,
1404            data_path=data_path,
1405            annotation_path=annotation_path,
1406            file_paths=file_paths,
1407            mode=mode,
1408            annotation_type=annotation_type,
1409        )
1410        results = task.evaluate_prediction(prediction_dict, data=mode)
1411        if remove_saved_features:
1412            self._remove_stores(parameters)
1413        results = Project._reformat_results(
1414            results[1],
1415            task.behaviors_dict(),
1416            exclusive=task.general_parameters["exclusive"],
1417        )
1418        return results

Make predictions and evaluate them. Inputs: prediction_name (str): the name of the prediction; parameters_update (dict): a dictionary of parameter updates; data_path (str): the data path to run the prediction for; annotation_path (str): the annotation path to run the prediction for; file_paths (set): a set of string file paths (data with all prefixes + feature files, in any order) to run the prediction for; mode (str): the subset of the data to make the prediction for (forced to 'all' if data_path is not None); remove_saved_features (bool): if True, pre-computed features will be deleted; annotation_type (str): the type of annotation to use for evaluation; num_classes (int): the number of classes in the dataset, must be set with data_path. Outputs: results (dict): a dictionary of average values of metric functions.

def evaluate( self, episode_names: List, load_epochs: List = None, augment_n: int = 0, data_path: str = None, file_paths: Set = None, mode: str = None, parameters_update: Dict = None, multiple_episode_policy: str = 'average', remove_saved_features: bool = False, skip_updating_meta: bool = True, annotation_type: str = 'none') -> Dict:
1420    def evaluate(
1421        self,
1422        episode_names: List,
1423        load_epochs: List = None,
1424        augment_n: int = 0,
1425        data_path: str = None,
1426        file_paths: Set = None,
1427        mode: str = None,
1428        parameters_update: Dict = None,
1429        multiple_episode_policy: str = "average",
1430        remove_saved_features: bool = False,
1431        skip_updating_meta: bool = True,
1432        annotation_type: str = "none",
1433    ) -> Dict:
1434        """Load one or several models from previously run episodes to make an evaluation.
1435
1436        By default it will run on the test (or validation, if there is no test) subset of the project dataset.
1437
1438        Parameters
1439        ----------
1440        episode_names : list
1441            a list of string episode names to load the models from
1442        load_epochs : list, optional
1443            a list of integer epoch indices to load the model from; if None, the last ones are used
1444        augment_n : int, default 0
1445            the number of augmentations to average over
1446        data_path : str, optional
1447            the data path to run the prediction for
1448        file_paths : set, optional
1449            a set of files to run the prediction for
1450        mode : {'test', 'val', 'train', 'all'}
1451            the subset of the data to make the prediction for (forced to 'all' if data_path is not None;
1452            by default 'test' if test subset is not empty and 'val' otherwise)
1453        parameters_update : dict, optional
1454            a dictionary with parameter updates (cannot change model parameters)
1455        multiple_episode_policy : {'average', 'statistics'}
1456            the policy to use when multiple episodes are provided
1457        remove_saved_features : bool, default False
1458            if `True`, the dataset will be deleted
1459        skip_updating_meta : bool, default True
1460            if `True`, the meta file will not be updated with the computed metrics
1461
1462        Returns
1463        -------
1464        metric : dict
1465            a dictionary of average values of metric functions
1466
1467        """
1468        names = []
1469        for episode_name in episode_names:
1470            names += self._episodes().get_runs(episode_name)
1471        if len(set(episode_names)) == 1:
1472            print(f"EVALUATION {episode_names[0]}")
1473        else:
1474            print(f"EVALUATION {episode_names}")
1475        if len(names) > 1:
1476            evaluate = True
1477        else:
1478            evaluate = False
1479        if multiple_episode_policy == "average":
1480            task, parameters, mode, prediction, inference_time, behavior_dict = (
1481                self._make_prediction(
1482                    "_",
1483                    episode_names,
1484                    load_epochs,
1485                    parameters_update,
1486                    mode=mode,
1487                    data_path=data_path,
1488                    file_paths=file_paths,
1489                    augment_n=augment_n,
1490                    evaluate=evaluate,
1491                    annotation_type=annotation_type,
1492                )
1493            )
1494            print("EVALUATE PREDICTION:")
1495            indices = [
1496                list(behavior_dict.keys()).index(i) for i in range(len(behavior_dict))
1497            ]
1498            _, results = task.evaluate_prediction(
1499                prediction, data=mode, indices=indices
1500            )
1501            if len(names) == 1 and mode == "val" and not skip_updating_meta:
1502                self._update_episode_metrics(names[0], results)
1503            results = Project._reformat_results(
1504                results,
1505                behavior_dict,
1506                exclusive=task.general_parameters["exclusive"],
1507            )
1508
1509        elif multiple_episode_policy == "statistics":
1510            values = defaultdict(lambda: [])
1511            task = None
1512            for name in names:
1513                (
1514                    task,
1515                    parameters,
1516                    mode,
1517                    prediction,
1518                    inference_time,
1519                    behavior_dict,
1520                ) = self._make_prediction(
1521                    "_",
1522                    [name],
1523                    load_epochs,
1524                    parameters_update,
1525                    mode=mode,
1526                    data_path=data_path,
1527                    file_paths=file_paths,
1528                    augment_n=augment_n,
1529                    evaluate=evaluate,
1530                    task=task,
1531                )
1532                _, metrics = task.evaluate_prediction(
1533                    prediction, data=mode, indices=list(behavior_dict.keys())
1534                )
1535                for name, value in metrics.items():
1536                    values[name].append(value)
1537                if mode == "val" and not skip_updating_meta:
1538                    self._update_episode_metrics(name, metrics)
1539            results = defaultdict(lambda: {})
1540            mean_string = ""
1541            std_string = ""
1542            for key, value_list in values.items():
1543                results[key]["mean"] = np.mean(value_list)
1544                results[key]["std"] = np.std(value_list)
1545                results[key]["all"] = value_list
1546                mean_string += f"{key} {np.mean(value_list):.3f}, "
1547                std_string += f"{key} {np.std(value_list):.3f}, "
1548            print("MEAN:")
1549            print(mean_string)
1550            print("STD:")
1551            print(std_string)
1552        else:
1553            raise ValueError(
1554                f"The {multiple_episode_policy} multiple episode policy is not recognized; please choose "
1555                f"from ['average', 'statistics']"
1556            )
1557        if len(names) > 0 and remove_saved_features:
1558            self._remove_stores(parameters)
1559        print(f"Inference time: {inference_time}")
1560        print("\n")
1561        return results

Load one or several models from previously run episodes to make an evaluation.

By default it will run on the test (or validation, if there is no test) subset of the project dataset.

Parameters

episode_names : list a list of string episode names to load the models from load_epochs : list, optional a list of integer epoch indices to load the model from; if None, the last ones are used augment_n : int, default 0 the number of augmentations to average over data_path : str, optional the data path to run the prediction for file_paths : set, optional a set of files to run the prediction for mode : {'test', 'val', 'train', 'all'} the subset of the data to make the prediction for (forced to 'all' if data_path is not None; by default 'test' if test subset is not empty and 'val' otherwise) parameters_update : dict, optional a dictionary with parameter updates (cannot change model parameters) multiple_episode_policy : {'average', 'statistics'} the policy to use when multiple episodes are provided remove_saved_features : bool, default False if True, the dataset will be deleted skip_updating_meta : bool, default True if True, the meta file will not be updated with the computed metrics

Returns

metric : dict a dictionary of average values of metric functions

def run_suggestion( self, suggestions_name: str, error_episode: str = None, error_load_epoch: int = None, error_class: str = None, suggestions_prediction: str = None, suggestion_episodes: List = [None], suggestion_load_epoch: int = None, suggestion_classes: List = None, error_threshold: float = 0.5, error_threshold_diff: float = 0.1, error_hysteresis: bool = False, suggestion_threshold: Union[float, List] = 0.5, suggestion_threshold_diff: Union[float, List] = 0.1, suggestion_hysteresis: Union[bool, List] = True, min_frames_suggestion: int = 10, min_frames_al: int = 30, visibility_min_score: float = 0, visibility_min_frac: float = 0.7, augment_n: int = 0, exclude_classes: List = None, exclude_threshold: Union[float, List] = 0.6, exclude_threshold_diff: Union[float, List] = 0.1, exclude_hysteresis: Union[bool, List] = False, include_classes: List = None, include_threshold: Union[float, List] = 0.4, include_threshold_diff: Union[float, List] = 0.1, include_hysteresis: Union[bool, List] = False, data_path: str = None, file_paths: Set = None, parameters_update: Dict = None, mode: str = 'all', force: bool = False, remove_saved_features: bool = False, cut_annotated: bool = False, background_threshold: float = None) -> None:
1563    def run_suggestion(
1564        self,
1565        suggestions_name: str,
1566        error_episode: str = None,
1567        error_load_epoch: int = None,
1568        error_class: str = None,
1569        suggestions_prediction: str = None,
1570        suggestion_episodes: List = [None],
1571        suggestion_load_epoch: int = None,
1572        suggestion_classes: List = None,
1573        error_threshold: float = 0.5,
1574        error_threshold_diff: float = 0.1,
1575        error_hysteresis: bool = False,
1576        suggestion_threshold: Union[float, List] = 0.5,
1577        suggestion_threshold_diff: Union[float, List] = 0.1,
1578        suggestion_hysteresis: Union[bool, List] = True,
1579        min_frames_suggestion: int = 10,
1580        min_frames_al: int = 30,
1581        visibility_min_score: float = 0,
1582        visibility_min_frac: float = 0.7,
1583        augment_n: int = 0,
1584        exclude_classes: List = None,
1585        exclude_threshold: Union[float, List] = 0.6,
1586        exclude_threshold_diff: Union[float, List] = 0.1,
1587        exclude_hysteresis: Union[bool, List] = False,
1588        include_classes: List = None,
1589        include_threshold: Union[float, List] = 0.4,
1590        include_threshold_diff: Union[float, List] = 0.1,
1591        include_hysteresis: Union[bool, List] = False,
1592        data_path: str = None,
1593        file_paths: Set = None,
1594        parameters_update: Dict = None,
1595        mode: str = "all",
1596        force: bool = False,
1597        remove_saved_features: bool = False,
1598        cut_annotated: bool = False,
1599        background_threshold: float = None,
1600    ) -> None:
1601        """Create active learning and suggestion files.
1602
1603        Generate predictions with the error and suggestion model and use them to create
1604        suggestion files for the labeling interface. Those files will render as suggested labels
1605        at intervals with high pose estimation quality. Quality here is defined by probability of error
1606        (predicted by the error model) and visibility parameters.
1607
1608        If `error_episode` or `exclude_classes` is not `None`,
1609        an active learning file will be created as well (with frames with high predicted probability of classes
1610        from `exclude_classes` and/or errors excluded from the active learning intervals).
1611
1612        In all three steps (predicting errors, suggesting labels and excluding them from active learning intervals)
1613        you can apply one of three methods.
1614
1615        - **Simple threshold**
1616
1617            Set the `hysteresis` parameter (e.g. `error_hysteresis`) to `False` and the `threshold`
1618            parameter to $\alpha$.
1619            In this case if the probability of a label is predicted to be higher than $\alpha$ the frame will
1620            be considered labeled.
1621
1622        - **Hysteresis threshold**
1623
1624            Set the `hysteresis` parameter (e.g. `error_hysteresis`) to `True`, the `threshold`
1625            parameter to $\alpha$ and the `threshold_diff` parameter to $\beta$.
1626            Now intervals will be marked with a label if the probability of that label for all frames is higher
1627            than $\alpha - \beta$ and at least for one frame in that interval it is higher than $\alpha$.
1628
1629        - **Max hysteresis threshold**
1630
1631            Set the `hysteresis` parameter (e.g. `error_hysteresis`) to `True`, the `threshold`
1632            parameter to $\alpha$ and the `threshold_diff` parameter to `None`.
1633            With this combination intervals are marked with a label if that label is more likely than any other
1634            for all frames in this interval and at for at least one of those frames its probability is higher than
1635            $\alpha$.
1636
1637        Parameters
1638        ----------
1639        suggestions_name : str
1640            the name of the suggestions
1641        error_episode : str, optional
1642            the name of the episode where the error model should be loaded from
1643        error_load_epoch : int, optional
1644            the epoch the error model should be loaded from
1645        error_class : str, optional
1646            the name of the error class (in `error_episode`)
1647        suggestions_prediction : str, optional
1648            the name of the predictions that should be used for the suggestion model
1649        suggestion_episodes : list, optional
1650            the names of the episodes where the suggestion models should be loaded from
1651        suggestion_load_epoch : int, optional
1652            the epoch the suggestion model should be loaded from
1653        suggestion_classes : list, optional
1654            a list of string names of the classes that should be suggested (in `suggestion_episode`)
1655        error_threshold : float, default 0.5
1656            the hard threshold for error prediction
1657        error_threshold_diff : float, default 0.1
1658            the difference between soft and hard thresholds for error prediction (in case hysteresis is used)
1659        error_hysteresis : bool, default False
1660            if True, hysteresis is used for error prediction
1661        suggestion_threshold : float | list, default 0.5
1662            the hard threshold for class prediction (use a list to set different rules for different classes)
1663        suggestion_threshold_diff : float | list, default 0.1
1664            the difference between soft and hard thresholds for class prediction (in case hysteresis is used;
1665            use a list to set different rules for different classes)
1666        suggestion_hysteresis : bool | list, default True
1667            if True, hysteresis is used for class prediction (use a list to set different rules for different classes)
1668        min_frames_suggestion : int, default 10
1669            only actions longer than this number of frames will be suggested
1670        min_frames_al : int, default 30
1671            only active learning intervals longer than this number of frames will be suggested
1672        visibility_min_score : float, default 0
1673            the minimum visibility score for visibility filtering
1674        visibility_min_frac : float, default 0.7
1675            the minimum fraction of visible frames for visibility filtering
1676        augment_n : int, default 10
1677            the number of augmentations to average the predictions over
1678        exclude_classes : list, optional
1679            a list of string names of classes that should be excluded from the active learning intervals
1680        exclude_threshold : float | list, default 0.6
1681            the hard threshold for excluded class prediction (use a list to set different rules for different classes)
1682        exclude_threshold_diff : float | list, default 0.1
1683            the difference between soft and hard thresholds for excluded class prediction (in case hysteresis is used)
1684        exclude_hysteresis : bool | list, default False
1685            if True, hysteresis is used for excluded class prediction (use a list to set different rules for different classes)
1686        include_classes : list, optional
1687            a list of string names of classes that should be included into the active learning intervals
1688        include_threshold : float | list, default 0.6
1689            the hard threshold for included class prediction (use a list to set different rules for different classes)
1690        include_threshold_diff : float | list, default 0.1
1691            the difference between soft and hard thresholds for included class prediction (in case hysteresis is used)
1692        include_hysteresis : bool | list, default False
1693            if True, hysteresis is used for included class prediction (use a list to set different rules for different classes)
1694        data_path : str, optional
1695            the data path to run the prediction for
1696        file_paths : set, optional
1697            a set of string file paths (data with all prefixes + feature files, in any order) to run the prediction
1698            for
1699        parameters_update : dict, optional
1700            the parameters update dictionary
1701        mode : {'all', 'test', 'val', 'train'}
1702            the subset of the data to make the prediction for (forced to 'all' if data_path is not None)
1703        force : bool, default False
1704            if `True` and an episode with name `episode_name` already exists, it will be overwritten (use with caution!)
1705        remove_saved_features : bool, default False
1706            if `True`, the dataset will be deleted.
1707        cut_annotated : bool, default False
1708            if `True`, annotated frames will be cut from the suggestions
1709        background_threshold : float, default 0.5
1710            the threshold for background prediction
1711
1712        """
1713        self._check_suggestions_validity(suggestions_name, force=force)
1714        if any([x is None for x in suggestion_episodes]):
1715            suggestion_episodes = None
1716        if error_episode is None and (
1717            suggestion_episodes is None and suggestions_prediction is None
1718        ):
1719            raise ValueError(
1720                "Both error_episode and suggestion_episode parameters cannot be None at the same time"
1721            )
1722        print(f"SUGGESTION {suggestions_name}")
1723        task = None
1724        if suggestion_classes is None:
1725            suggestion_classes = []
1726        if exclude_classes is None:
1727            exclude_classes = []
1728        if include_classes is None:
1729            include_classes = []
1730        if isinstance(suggestion_threshold, list):
1731            if len(suggestion_threshold) != len(suggestion_classes):
1732                raise ValueError(
1733                    "The suggestion_threshold parameter has to be either a float value or a list of "
1734                    f"float values of the same length as suggestion_classes (got a list of length "
1735                    f"{len(suggestion_threshold)} for {len(suggestion_classes)} classes)"
1736                )
1737        else:
1738            suggestion_threshold = [suggestion_threshold for _ in suggestion_classes]
1739        if isinstance(suggestion_threshold_diff, list):
1740            if len(suggestion_threshold_diff) != len(suggestion_classes):
1741                raise ValueError(
1742                    "The suggestion_threshold_diff parameter has to be either a float value or a list of "
1743                    f"float values of the same length as suggestion_classes (got a list of length "
1744                    f"{len(suggestion_threshold)} for {len(suggestion_classes)} classes)"
1745                )
1746        else:
1747            suggestion_threshold_diff = [
1748                suggestion_threshold_diff for _ in suggestion_classes
1749            ]
1750        if isinstance(suggestion_hysteresis, list):
1751            if len(suggestion_hysteresis) != len(suggestion_classes):
1752                raise ValueError(
1753                    "The suggestion_threshold_diff parameter has to be either a float value or a list of "
1754                    f"float values of the same length as suggestion_classes (got a list of length "
1755                    f"{len(suggestion_hysteresis)} for {len(suggestion_classes)} classes)"
1756                )
1757        else:
1758            suggestion_hysteresis = [suggestion_hysteresis for _ in suggestion_classes]
1759        if isinstance(exclude_threshold, list):
1760            if len(exclude_threshold) != len(exclude_classes):
1761                raise ValueError(
1762                    "The exclude_threshold parameter has to be either a float value or a list of "
1763                    f"float values of the same length as exclude_classes (got a list of length "
1764                    f"{len(exclude_threshold)} for {len(exclude_classes)} classes)"
1765                )
1766        else:
1767            exclude_threshold = [exclude_threshold for _ in exclude_classes]
1768        if isinstance(exclude_threshold_diff, list):
1769            if len(exclude_threshold_diff) != len(exclude_classes):
1770                raise ValueError(
1771                    "The exclude_threshold_diff parameter has to be either a float value or a list of "
1772                    f"float values of the same length as exclude_classes (got a list of length "
1773                    f"{len(exclude_threshold_diff)} for {len(exclude_classes)} classes)"
1774                )
1775        else:
1776            exclude_threshold_diff = [exclude_threshold_diff for _ in exclude_classes]
1777        if isinstance(exclude_hysteresis, list):
1778            if len(exclude_hysteresis) != len(exclude_classes):
1779                raise ValueError(
1780                    "The suggestion_threshold_diff parameter has to be either a float value or a list of "
1781                    f"float values of the same length as suggestion_classes (got a list of length "
1782                    f"{len(exclude_hysteresis)} for {len(exclude_classes)} classes)"
1783                )
1784        else:
1785            exclude_hysteresis = [exclude_hysteresis for _ in exclude_classes]
1786        if isinstance(include_threshold, list):
1787            if len(include_threshold) != len(include_classes):
1788                raise ValueError(
1789                    "The exclude_threshold parameter has to be either a float value or a list of "
1790                    f"float values of the same length as exclude_classes (got a list of length "
1791                    f"{len(include_threshold)} for {len(include_classes)} classes)"
1792                )
1793        else:
1794            include_threshold = [include_threshold for _ in include_classes]
1795        if isinstance(include_threshold_diff, list):
1796            if len(include_threshold_diff) != len(include_classes):
1797                raise ValueError(
1798                    "The exclude_threshold_diff parameter has to be either a float value or a list of "
1799                    f"float values of the same length as exclude_classes (got a list of length "
1800                    f"{len(include_threshold_diff)} for {len(include_classes)} classes)"
1801                )
1802        else:
1803            include_threshold_diff = [include_threshold_diff for _ in include_classes]
1804        if isinstance(include_hysteresis, list):
1805            if len(include_hysteresis) != len(include_classes):
1806                raise ValueError(
1807                    "The suggestion_threshold_diff parameter has to be either a float value or a list of "
1808                    f"float values of the same length as suggestion_classes (got a list of length "
1809                    f"{len(include_hysteresis)} for {len(include_classes)} classes)"
1810                )
1811        else:
1812            include_hysteresis = [include_hysteresis for _ in include_classes]
1813        if (suggestion_episodes is None and suggestions_prediction is None) and len(
1814            exclude_classes
1815        ) > 0:
1816            raise ValueError(
1817                "In order to exclude classes from the active learning intervals you need to set the "
1818                "suggestion_episode parameter"
1819            )
1820
1821        task = None
1822        if error_episode is not None:
1823            task, parameters, mode = self._make_task_prediction(
1824                prediction_name=suggestions_name,
1825                load_episode=error_episode,
1826                parameters_update=parameters_update,
1827                load_epoch=error_load_epoch,
1828                data_path=data_path,
1829                mode=mode,
1830                file_paths=file_paths,
1831                task=task,
1832            )
1833            predicted_error = task.predict(
1834                data=mode,
1835                raw_output=True,
1836                apply_primary_function=True,
1837                augment_n=augment_n,
1838            )
1839        else:
1840            predicted_error = None
1841
1842        if suggestion_episodes is not None:
1843            (
1844                task,
1845                parameters,
1846                mode,
1847                predicted_classes,
1848                inference_time,
1849                behavior_dict,
1850            ) = self._make_prediction(
1851                prediction_name=suggestions_name,
1852                episode_names=suggestion_episodes,
1853                load_epochs=suggestion_load_epoch,
1854                parameters_update=parameters_update,
1855                data_path=data_path,
1856                file_paths=file_paths,
1857                mode=mode,
1858                task=task,
1859            )
1860        elif suggestions_prediction is not None:
1861            with open(
1862                os.path.join(
1863                    self.project_path,
1864                    "results",
1865                    "predictions",
1866                    f"{suggestions_prediction}.pickle",
1867                ),
1868                "rb",
1869            ) as f:
1870                predicted_classes = pickle.load(f)
1871            if parameters_update is None:
1872                parameters_update = {}
1873            parameters_update = self._update(
1874                self._predictions().load_parameters(suggestions_prediction),
1875                parameters_update,
1876            )
1877            parameters_update.pop("model")
1878            if suggestion_episodes is None:
1879                suggestion_episodes = [
1880                    os.path.basename(
1881                        os.path.dirname(
1882                            parameters_update["training"]["checkpoint_path"]
1883                        )
1884                    )
1885                ]
1886            task, parameters, mode = self._make_task_prediction(
1887                "_",
1888                load_episode=None,
1889                parameters_update=parameters_update,
1890                data_path=data_path,
1891                file_paths=file_paths,
1892                mode=mode,
1893            )
1894        else:
1895            predicted_classes = None
1896
1897        if len(suggestion_classes) > 0 and predicted_classes is not None:
1898            suggestions = self._make_suggestions(
1899                task,
1900                predicted_error,
1901                predicted_classes,
1902                suggestion_threshold,
1903                suggestion_threshold_diff,
1904                suggestion_hysteresis,
1905                suggestion_episodes,
1906                suggestion_classes,
1907                error_threshold,
1908                min_frames_suggestion,
1909                min_frames_al,
1910                visibility_min_score,
1911                visibility_min_frac,
1912                cut_annotated=cut_annotated,
1913            )
1914            videos = list(suggestions.keys())
1915            for v_id in videos:
1916                times_dict = defaultdict(lambda: defaultdict(lambda: []))
1917                clips = set()
1918                for c in suggestions[v_id]:
1919                    for start, end, ind in suggestions[v_id][c]:
1920                        times_dict[ind][c].append([start, end, 2])
1921                        clips.add(ind)
1922                clips = list(clips)
1923                times_dict = dict(times_dict)
1924                times = [
1925                    [times_dict[ind][c] for c in suggestion_classes] for ind in clips
1926                ]
1927                save_path = self._suggestion_path(v_id, suggestions_name)
1928                Path(os.path.dirname(save_path)).mkdir(parents=True, exist_ok=True)
1929                with open(save_path, "wb") as f:
1930                    pickle.dump((None, suggestion_classes, clips, times), f)
1931
1932        if (
1933            error_episode is not None
1934            or len(exclude_classes) > 0
1935            or len(include_classes) > 0
1936        ):
1937            al_points = self._make_al_points(
1938                task,
1939                predicted_error,
1940                predicted_classes,
1941                exclude_classes,
1942                exclude_threshold,
1943                exclude_threshold_diff,
1944                exclude_hysteresis,
1945                include_classes,
1946                include_threshold,
1947                include_threshold_diff,
1948                include_hysteresis,
1949                error_episode,
1950                error_class,
1951                suggestion_episodes,
1952                error_threshold,
1953                error_threshold_diff,
1954                error_hysteresis,
1955                min_frames_al,
1956                visibility_min_score,
1957                visibility_min_frac,
1958            )
1959        else:
1960            al_points = self._make_al_points_from_suggestions(
1961                suggestions_name,
1962                task,
1963                predicted_classes,
1964                background_threshold,
1965                visibility_min_score,
1966                visibility_min_frac,
1967                num_behaviors=len(task.behaviors_dict()),
1968            )
1969        save_path = self._al_points_path(suggestions_name)
1970        Path(os.path.dirname(save_path)).mkdir(parents=True, exist_ok=True)
1971        with open(save_path, "wb") as f:
1972            pickle.dump(al_points, f)
1973
1974        meta_parameters = {
1975            "error_episode": error_episode,
1976            "error_load_epoch": error_load_epoch,
1977            "error_class": error_class,
1978            "suggestion_episode": suggestion_episodes,
1979            "suggestion_load_epoch": suggestion_load_epoch,
1980            "suggestion_classes": suggestion_classes,
1981            "error_threshold": error_threshold,
1982            "error_threshold_diff": error_threshold_diff,
1983            "error_hysteresis": error_hysteresis,
1984            "suggestion_threshold": suggestion_threshold,
1985            "suggestion_threshold_diff": suggestion_threshold_diff,
1986            "suggestion_hysteresis": suggestion_hysteresis,
1987            "min_frames_suggestion": min_frames_suggestion,
1988            "min_frames_al": min_frames_al,
1989            "visibility_min_score": visibility_min_score,
1990            "visibility_min_frac": visibility_min_frac,
1991            "augment_n": augment_n,
1992            "exclude_classes": exclude_classes,
1993            "exclude_threshold": exclude_threshold,
1994            "exclude_threshold_diff": exclude_threshold_diff,
1995            "exclude_hysteresis": exclude_hysteresis,
1996        }
1997        self._save_suggestions(suggestions_name, {}, meta_parameters)
1998        if data_path is not None or file_paths is not None or remove_saved_features:
1999            self._remove_stores(parameters)
2000        print(f"\n")

Create active learning and suggestion files.

Generate predictions with the error and suggestion model and use them to create suggestion files for the labeling interface. Those files will render as suggested labels at intervals with high pose estimation quality. Quality here is defined by probability of error (predicted by the error model) and visibility parameters.

If error_episode or exclude_classes is not None, an active learning file will be created as well (with frames with high predicted probability of classes from exclude_classes and/or errors excluded from the active learning intervals).

In all three steps (predicting errors, suggesting labels and excluding them from active learning intervals) you can apply one of three methods.

  • Simple threshold

    Set the hysteresis parameter (e.g. error_hysteresis) to False and the threshold parameter to $\alpha$. In this case, if the probability of a label is predicted to be higher than $\alpha$, the frame will be considered labeled.

  • Hysteresis threshold

    Set the hysteresis parameter (e.g. error_hysteresis) to True, the threshold parameter to $\alpha$ and the threshold_diff parameter to $\beta$. Now intervals will be marked with a label if the probability of that label for all frames is higher than $\alpha - \beta$ and for at least one frame in that interval it is higher than $\alpha$.

  • Max hysteresis threshold

    Set the hysteresis parameter (e.g. error_hysteresis) to True, the threshold parameter to $\alpha$ and the threshold_diff parameter to None. With this combination intervals are marked with a label if that label is more likely than any other for all frames in this interval and for at least one of those frames its probability is higher than $\alpha$.

Parameters

suggestions_name : str the name of the suggestions error_episode : str, optional the name of the episode where the error model should be loaded from error_load_epoch : int, optional the epoch the error model should be loaded from error_class : str, optional the name of the error class (in error_episode) suggestions_prediction : str, optional the name of the predictions that should be used for the suggestion model suggestion_episodes : list, optional the names of the episodes where the suggestion models should be loaded from suggestion_load_epoch : int, optional the epoch the suggestion model should be loaded from suggestion_classes : list, optional a list of string names of the classes that should be suggested (in suggestion_episode) error_threshold : float, default 0.5 the hard threshold for error prediction error_threshold_diff : float, default 0.1 the difference between soft and hard thresholds for error prediction (in case hysteresis is used) error_hysteresis : bool, default False if True, hysteresis is used for error prediction suggestion_threshold : float | list, default 0.5 the hard threshold for class prediction (use a list to set different rules for different classes) suggestion_threshold_diff : float | list, default 0.1 the difference between soft and hard thresholds for class prediction (in case hysteresis is used; use a list to set different rules for different classes) suggestion_hysteresis : bool | list, default True if True, hysteresis is used for class prediction (use a list to set different rules for different classes) min_frames_suggestion : int, default 10 only actions longer than this number of frames will be suggested min_frames_al : int, default 30 only active learning intervals longer than this number of frames will be suggested visibility_min_score : float, default 0 the minimum visibility score for visibility filtering visibility_min_frac : float, default 0.7 the minimum fraction of visible frames for visibility filtering augment_n : int, 
default 10 the number of augmentations to average the predictions over exclude_classes : list, optional a list of string names of classes that should be excluded from the active learning intervals exclude_threshold : float | list, default 0.6 the hard threshold for excluded class prediction (use a list to set different rules for different classes) exclude_threshold_diff : float | list, default 0.1 the difference between soft and hard thresholds for excluded class prediction (in case hysteresis is used) exclude_hysteresis : bool | list, default False if True, hysteresis is used for excluded class prediction (use a list to set different rules for different classes) include_classes : list, optional a list of string names of classes that should be included into the active learning intervals include_threshold : float | list, default 0.6 the hard threshold for included class prediction (use a list to set different rules for different classes) include_threshold_diff : float | list, default 0.1 the difference between soft and hard thresholds for included class prediction (in case hysteresis is used) include_hysteresis : bool | list, default False if True, hysteresis is used for included class prediction (use a list to set different rules for different classes) data_path : str, optional the data path to run the prediction for file_paths : set, optional a set of string file paths (data with all prefixes + feature files, in any order) to run the prediction for parameters_update : dict, optional the parameters update dictionary mode : {'all', 'test', 'val', 'train'} the subset of the data to make the prediction for (forced to 'all' if data_path is not None) force : bool, default False if True and an episode with name episode_name already exists, it will be overwritten (use with caution!) remove_saved_features : bool, default False if True, the dataset will be deleted. 
cut_annotated : bool, default False if True, annotated frames will be cut from the suggestions background_threshold : float, default 0.5 the threshold for background prediction

def suggest_intervals_with_similarity( self, suggestions_name: str, prediction_name: str, target_video_id: str, target_clip: str, target_start: int, target_end: int, min_length: int = 60, n_intervals: int = 5, force: bool = False):
2095    def suggest_intervals_with_similarity(
2096        self,
2097        suggestions_name: str,
2098        prediction_name: str,
2099        target_video_id: str,
2100        target_clip: str,
2101        target_start: int,
2102        target_end: int,
2103        min_length: int = 60,
2104        n_intervals: int = 5,
2105        force: bool = False,
2106    ):
2107        """
2108        Suggest intervals based on similarity to a target interval.
2109
2110        Parameters
2111        ----------
2112        suggestions_name : str
2113            Name of the suggestion.
2114        prediction_name : str
2115            Name of the prediction to use.
2116        target_video_id : str
2117            Video id of the target interval.
2118        target_clip : str
2119            Clip id of the target interval.
2120        target_start : int
2121            Start frame of the target interval.
2122        target_end : int
2123            End frame of the target interval.
2124        min_length : int, default 60
2125            Minimum length of the suggested intervals.
2126        n_intervals : int, default 5
2127            Number of suggested intervals.
2128        force : bool, default False
2129            If True, the suggestion is overwritten if it already exists.
2130
2131        """
2132        self._check_suggestions_validity(suggestions_name, force=force)
2133        print(f"SUGGESTION {suggestions_name}")
2134        score_dict = self._generate_similarity_score(
2135            prediction_name, target_video_id, target_clip, target_start, target_end
2136        )
2137        intervals = self._suggest_intervals_from_dict(
2138            score_dict, min_length, n_intervals
2139        )
2140        suggestions_path = os.path.join(
2141            self.project_path,
2142            "results",
2143            "suggestions",
2144            suggestions_name,
2145        )
2146        if not os.path.exists(suggestions_path):
2147            os.mkdir(suggestions_path)
2148        with open(
2149            os.path.join(suggestions_path, f"{suggestions_name}_al_points.pickle"), "wb"
2150        ) as f:
2151            pickle.dump(intervals, f)
2152        meta_parameters = {
2153            "prediction_name": prediction_name,
2154            "min_frames_suggestion": min_length,
2155            "n_intervals": n_intervals,
2156            "target_clip": target_clip,
2157            "target_start": target_start,
2158            "target_end": target_end,
2159        }
2160        self._save_suggestions(suggestions_name, {}, meta_parameters)
2161        print("\n")

Suggest intervals based on similarity to a target interval.

Parameters

suggestions_name : str Name of the suggestion. prediction_name : str Name of the prediction to use. target_video_id : str Video id of the target interval. target_clip : str Clip id of the target interval. target_start : int Start frame of the target interval. target_end : int End frame of the target interval. min_length : int, default 60 Minimum length of the suggested intervals. n_intervals : int, default 5 Number of suggested intervals. force : bool, default False If True, the suggestion is overwritten if it already exists.

def suggest_intervals_with_uncertainty( self, suggestions_name: str, episode_names: List, load_epochs: List = None, classes: List = None, n_frames: int = 10000, method: str = 'least_confidence', min_length: int = 60, augment_n: int = 0, data_path: str = None, file_paths: Set = None, parameters_update: Dict = None, mode: str = 'all', force: bool = False, remove_saved_features: bool = False) -> None:
2163    def suggest_intervals_with_uncertainty(
2164        self,
2165        suggestions_name: str,
2166        episode_names: List,
2167        load_epochs: List = None,
2168        classes: List = None,
2169        n_frames: int = 10000,
2170        method: str = "least_confidence",
2171        min_length: int = 60,
2172        augment_n: int = 0,
2173        data_path: str = None,
2174        file_paths: Set = None,
2175        parameters_update: Dict = None,
2176        mode: str = "all",
2177        force: bool = False,
2178        remove_saved_features: bool = False,
2179    ) -> None:
2180        """Generate an active learning file based on model uncertainty.
2181
2182        If you provide several episode names, the predicted probabilities will be averaged.
2183
2184        Parameters
2185        ----------
2186        suggestions_name : str
2187            the name of the suggestion
2188        episode_names : list
2189            a list of string episode names to load the models from
2190        load_epochs : list, optional
2191            a list of epoch indices to load the models from (if `None`, the last ones will be used)
2192        classes : list, optional
2193            a list of classes to look at (by default all)
2194        n_frames : int, default 10000
2195            the threshold total number of frames in the suggested intervals (in the end result it will most likely
2196            be slightly larger; it will only be smaller if the algorithm fails to find enough intervals
2197            with the set parameters)
2198        method : {"least_confidence", "entropy"}
2199            the method used to calculate the scores from the probability predictions (`"least_confidence"`: `1 - p_i` if
2200            `p_i > 0.5` or `p_i` if `p_i < 0.5`; `"entropy"`: `- p_i * log(p_i) - (1 - p_i) * log(1 - p_i)`)
2201        min_length : int, default 60
2202            the minimum number of frames in one interval
2203        augment_n : int, default 0
2204            the number of augmentations to average the predictions over
2205        data_path : str, optional
2206            the path to a data folder (by default, the project data is used)
2207        file_paths : set, optional
2208            a list of file paths (by default, the project data is used)
2209        parameters_update : dict, optional
2210            a dictionary of parameter updates
2211        mode : {"test", "val", "train", "all"}
2212            the subset of the data to make the prediction for (forced to 'all' if `data_path` is not `None`;
2213            by default set to `'test'` if the test subset if not empty, or to `'val'` otherwise)
2214        force : bool, default False
2215            if `True`, existing suggestions with the same name will be overwritten
2216        remove_saved_features : bool, default False
2217            if `True`, the dataset will be deleted after the computation
2218
2219        """
2220        self._check_suggestions_validity(suggestions_name, force=force)
2221        print(f"SUGGESTION {suggestions_name}")
2222        task, parameters, mode, predicted, inference_time, behavior_dict = (
2223            self._make_prediction(
2224                suggestions_name,
2225                episode_names,
2226                load_epochs,
2227                parameters_update,
2228                data_path=data_path,
2229                file_paths=file_paths,
2230                mode=mode,
2231                augment_n=augment_n,
2232                evaluate=False,
2233            )
2234        )
2235        if classes is None:
2236            classes = behavior_dict.values()
2237        episode = self._episodes().get_runs(episode_names[0])[0]
2238        score_tensors = task.generate_uncertainty_score(
2239            classes,
2240            augment_n,
2241            method,
2242            predicted,
2243            self._episode(episode).get_behaviors_dict(),
2244        )
2245        intervals = self._suggest_intervals(
2246            task.dataset(mode), score_tensors, n_frames, min_length
2247        )
2248        for k, v in intervals.items():
2249            l = sum([x[1] - x[0] for x in v])
2250            print(f"{k}: {len(v)} ({l})")
2251        if remove_saved_features:
2252            self._remove_stores(parameters)
2253        suggestions_path = os.path.join(
2254            self.project_path,
2255            "results",
2256            "suggestions",
2257            suggestions_name,
2258        )
2259        if not os.path.exists(suggestions_path):
2260            os.mkdir(suggestions_path)
2261        with open(
2262            os.path.join(suggestions_path, f"{suggestions_name}_al_points.pickle"), "wb"
2263        ) as f:
2264            pickle.dump(intervals, f)
2265        meta_parameters = {
2266            "suggestion_episode": episode_names,
2267            "suggestion_load_epoch": load_epochs,
2268            "suggestion_classes": classes,
2269            "min_frames_suggestion": min_length,
2270            "augment_n": augment_n,
2271            "method": method,
2272            "num_frames": n_frames,
2273        }
2274        self._save_suggestions(suggestions_name, {}, meta_parameters)
2275        print("\n")

Generate an active learning file based on model uncertainty.

If you provide several episode names, the predicted probabilities will be averaged.

Parameters

suggestions_name : str the name of the suggestion episode_names : list a list of string episode names to load the models from load_epochs : list, optional a list of epoch indices to load the models from (if None, the last ones will be used) classes : list, optional a list of classes to look at (by default all) n_frames : int, default 10000 the threshold total number of frames in the suggested intervals (in the end result it will most likely be slightly larger; it will only be smaller if the algorithm fails to find enough intervals with the set parameters) method : {"least_confidence", "entropy"} the method used to calculate the scores from the probability predictions ("least_confidence": 1 - p_i if p_i > 0.5 or p_i if p_i < 0.5; "entropy": - p_i * log(p_i) - (1 - p_i) * log(1 - p_i)) min_length : int, default 60 the minimum number of frames in one interval augment_n : int, default 0 the number of augmentations to average the predictions over data_path : str, optional the path to a data folder (by default, the project data is used) file_paths : set, optional a set of file paths (by default, the project data is used) parameters_update : dict, optional a dictionary of parameter updates mode : {"test", "val", "train", "all"} the subset of the data to make the prediction for (forced to 'all' if data_path is not None; by default set to 'test' if the test subset is not empty, or to 'val' otherwise) force : bool, default False if True, existing suggestions with the same name will be overwritten remove_saved_features : bool, default False if True, the dataset will be deleted after the computation

def suggest_intervals_with_bald( self, suggestions_name: str, episode_name: str, load_epoch: int = None, classes: List = None, n_frames: int = 10000, num_models: int = 10, kernel_size: int = 11, min_length: int = 60, augment_n: int = 0, data_path: str = None, file_paths: Set = None, parameters_update: Dict = None, mode: str = 'all', force: bool = False, remove_saved_features: bool = False):
2277    def suggest_intervals_with_bald(
2278        self,
2279        suggestions_name: str,
2280        episode_name: str,
2281        load_epoch: int = None,
2282        classes: List = None,
2283        n_frames: int = 10000,
2284        num_models: int = 10,
2285        kernel_size: int = 11,
2286        min_length: int = 60,
2287        augment_n: int = 0,
2288        data_path: str = None,
2289        file_paths: Set = None,
2290        parameters_update: Dict = None,
2291        mode: str = "all",
2292        force: bool = False,
2293        remove_saved_features: bool = False,
2294    ):
2295        """Generate an active learning file based on Bayesian Active Learning by Disagreement.
2296
2297        Parameters
2298        ----------
2299        suggestions_name : str
2300            the name of the suggestion
2301        episode_name : str
2302            the name of the episode to load the model from
2303        load_epoch : int, optional
2304            the index of the epoch to load the model from (if `None`, the last one will be used)
2305        classes : list, optional
2306            a list of classes to look at (by default all)
2307        n_frames : int, default 10000
2308            the threshold total number of frames in the suggested intervals (in the end result it will most likely
2309            be slightly larger; it will only be smaller if the algorithm fails to find enough intervals
2310            with the set parameters)
2311        num_models : int, default 10
2312            the number of dropout masks to apply
2313        kernel_size : int, default 11
2314            the size of the smoothing kernel applied to the discrete results
2315        min_length : int, default 60
2316            the minimum number of frames in one interval
2317        augment_n : int, default 0
2318            the number of augmentations to average the predictions over
2319        data_path : str, optional
2320            the path to a data folder (by default, the project data is used)
2321        file_paths : set, optional
2322            a list of file paths (by default, the project data is used)
2323        parameters_update : dict, optional
2324            a dictionary of parameter updates
2325        mode : {"test", "val", "train", "all"}
2326            the subset of the data to make the prediction for (forced to 'all' if `data_path` is not `None`;
2327            by default set to `'test'` if the test subset if not empty, or to `'val'` otherwise)
2328        force : bool, default False
2329            if `True`, existing suggestions with the same name will be overwritten
2330        remove_saved_features : bool, default False
2331            if `True`, the dataset will be deleted after the computation
2332
2333        """
2334        self._check_suggestions_validity(suggestions_name, force=force)
2335        print(f"SUGGESTION {suggestions_name}")
2336        task, parameters, mode = self._make_task_prediction(
2337            suggestions_name,
2338            episode_name,
2339            parameters_update,
2340            load_epoch,
2341            data_path=data_path,
2342            file_paths=file_paths,
2343            mode=mode,
2344        )
2345        if classes is None:
2346            classes = list(task.behaviors_dict().values())
2347        score_tensors = task.generate_bald_score(
2348            classes, augment_n, num_models, kernel_size
2349        )
2350        intervals = self._suggest_intervals(
2351            task.dataset(mode), score_tensors, n_frames, min_length
2352        )
2353        if remove_saved_features:
2354            self._remove_stores(parameters)
2355        suggestions_path = os.path.join(
2356            self.project_path,
2357            "results",
2358            "suggestions",
2359            suggestions_name,
2360        )
2361        if not os.path.exists(suggestions_path):
2362            os.mkdir(suggestions_path)
2363        with open(
2364            os.path.join(suggestions_path, f"{suggestions_name}_al_points.pickle"), "wb"
2365        ) as f:
2366            pickle.dump(intervals, f)
2367        meta_parameters = {
2368            "suggestion_episode": episode_name,
2369            "suggestion_load_epoch": load_epoch,
2370            "suggestion_classes": classes,
2371            "min_frames_suggestion": min_length,
2372            "augment_n": augment_n,
2373            "method": f"BALD:{num_models}",
2374            "num_frames": n_frames,
2375        }
2376        self._save_suggestions(suggestions_name, {}, meta_parameters)
2377        print("\n")

Generate an active learning file based on Bayesian Active Learning by Disagreement.

Parameters

suggestions_name : str the name of the suggestion episode_name : str the name of the episode to load the model from load_epoch : int, optional the index of the epoch to load the model from (if None, the last one will be used) classes : list, optional a list of classes to look at (by default all) n_frames : int, default 10000 the threshold total number of frames in the suggested intervals (in the end result it will most likely be slightly larger; it will only be smaller if the algorithm fails to find enough intervals with the set parameters) num_models : int, default 10 the number of dropout masks to apply kernel_size : int, default 11 the size of the smoothing kernel applied to the discrete results min_length : int, default 60 the minimum number of frames in one interval augment_n : int, default 0 the number of augmentations to average the predictions over data_path : str, optional the path to a data folder (by default, the project data is used) file_paths : set, optional a set of file paths (by default, the project data is used) parameters_update : dict, optional a dictionary of parameter updates mode : {"test", "val", "train", "all"} the subset of the data to make the prediction for (forced to 'all' if data_path is not None; by default set to 'test' if the test subset is not empty, or to 'val' otherwise) force : bool, default False if True, existing suggestions with the same name will be overwritten remove_saved_features : bool, default False if True, the dataset will be deleted after the computation

def list_episodes( self, episode_names: List = None, value_filter: str = '', display_parameters: List = None, print_results: bool = True) -> pandas.DataFrame:
2379    def list_episodes(
2380        self,
2381        episode_names: List = None,
2382        value_filter: str = "",
2383        display_parameters: List = None,
2384        print_results: bool = True,
2385    ) -> pd.DataFrame:
2386        """Get a filtered pandas dataframe with episode metadata.
2387
2388        Parameters
2389        ----------
2390        episode_names : list
2391            a list of strings of episode names
2392        value_filter : str
2393            a string of filters to apply; of this general structure:
2394            'group_name1/par_name1::(</>/<=/>=/=)value1,group_name2/par_name2::(</>/<=/>=/=)value2', e.g.
2395            'data/overlap::=50,results/recall::>0.5,data/feature_extraction::=kinematic,meta/training_time::>=00:00:10'
2396        display_parameters : list
2397            list of parameters to display (e.g. ['data/overlap', 'results/recall'])
2398        print_results : bool, default True
2399            if True, the result will be printed to standard output
2400
2401        Returns
2402        -------
2403        pd.DataFrame
2404            the filtered dataframe
2405
2406        """
2407        episodes = self._episodes().list_episodes(
2408            episode_names, value_filter, display_parameters
2409        )
2410        if print_results:
2411            print("TRAINING EPISODES")
2412            print(episodes)
2413            print("\n")
2414        return episodes

Get a filtered pandas dataframe with episode metadata.

Parameters

episode_names : list a list of strings of episode names value_filter : str a string of filters to apply; of this general structure: 'group_name1/par_name1::(</>/<=/>=/=)value1,group_name2/par_name2::(</>/<=/>=/=)value2', e.g. 'data/overlap::=50,results/recall::>0.5,data/feature_extraction::=kinematic,meta/training_time::>=00:00:10' display_parameters : list list of parameters to display (e.g. ['data/overlap', 'results/recall']) print_results : bool, default True if True, the result will be printed to standard output

Returns

pd.DataFrame the filtered dataframe

def list_predictions( self, episode_names: List = None, value_filter: str = '', display_parameters: List = None, print_results: bool = True) -> pandas.DataFrame:
2416    def list_predictions(
2417        self,
2418        episode_names: List = None,
2419        value_filter: str = "",
2420        display_parameters: List = None,
2421        print_results: bool = True,
2422    ) -> pd.DataFrame:
2423        """Get a filtered pandas dataframe with prediction metadata.
2424
2425        Parameters
2426        ----------
2427        episode_names : list
2428            a list of strings of episode names
2429        value_filter : str
2430            a string of filters to apply; of this general structure:
2431            'group_name1/par_name1:(<>=)value1,group_name2/par_name2:(<>=)value2', e.g.
2432            'data/overlap:=50,results/recall:>0.5,data/feature_extraction:=kinematic'
2433        display_parameters : list
2434            list of parameters to display (e.g. ['data/overlap', 'results/recall'])
2435        print_results : bool, default True
2436            if True, the result will be printed to standard output
2437
2438        Returns
2439        -------
2440        pd.DataFrame
2441            the filtered dataframe
2442
2443        """
2444        predictions = self._predictions().list_episodes(
2445            episode_names, value_filter, display_parameters
2446        )
2447        if print_results:
2448            print("PREDICTIONS")
2449            print(predictions)
2450            print("\n")
2451        return predictions

Get a filtered pandas dataframe with prediction metadata.

Parameters

episode_names : list a list of strings of episode names value_filter : str a string of filters to apply; of this general structure: 'group_name1/par_name1:(<>=)value1,group_name2/par_name2:(<>=)value2', e.g. 'data/overlap:=50,results/recall:>0.5,data/feature_extraction:=kinematic' display_parameters : list list of parameters to display (e.g. ['data/overlap', 'results/recall']) print_results : bool, default True if True, the result will be printed to standard output

Returns

pd.DataFrame the filtered dataframe

def list_suggestions( self, suggestions_names: List = None, value_filter: str = '', display_parameters: List = None, print_results: bool = True) -> pandas.DataFrame:
2453    def list_suggestions(
2454        self,
2455        suggestions_names: List = None,
2456        value_filter: str = "",
2457        display_parameters: List = None,
2458        print_results: bool = True,
2459    ) -> pd.DataFrame:
2460        """Get a filtered pandas dataframe with prediction metadata.
2461
2462        Parameters
2463        ----------
2464        suggestions_names : list
2465            a list of strings of suggestion names
2466        value_filter : str
2467            a string of filters to apply; of this general structure:
2468            'group_name1/par_name1:(<>=)value1,group_name2/par_name2:(<>=)value2', e.g.
2469            'data/overlap:=50,results/recall:>0.5,data/feature_extraction:=kinematic'
2470        display_parameters : list
2471            list of parameters to display (e.g. ['data/overlap', 'results/recall'])
2472        print_results : bool, default True
2473            if True, the result will be printed to standard output
2474
2475        Returns
2476        -------
2477        pd.DataFrame
2478            the filtered dataframe
2479
2480        """
2481        suggestions = self._suggestions().list_episodes(
2482            suggestions_names, value_filter, display_parameters
2483        )
2484        if print_results:
2485            print("SUGGESTIONS")
2486            print(suggestions)
2487            print("\n")
2488        return suggestions

Get a filtered pandas dataframe with suggestion metadata.

Parameters

suggestions_names : list a list of strings of suggestion names value_filter : str a string of filters to apply; of this general structure: 'group_name1/par_name1:(<>=)value1,group_name2/par_name2:(<>=)value2', e.g. 'data/overlap:=50,results/recall:>0.5,data/feature_extraction:=kinematic' display_parameters : list list of parameters to display (e.g. ['data/overlap', 'results/recall']) print_results : bool, default True if True, the result will be printed to standard output

Returns

pd.DataFrame the filtered dataframe

def list_searches( self, search_names: List = None, value_filter: str = '', display_parameters: List = None, print_results: bool = True) -> pandas.DataFrame:
2490    def list_searches(
2491        self,
2492        search_names: List = None,
2493        value_filter: str = "",
2494        display_parameters: List = None,
2495        print_results: bool = True,
2496    ) -> pd.DataFrame:
2497        """Get a filtered pandas dataframe with hyperparameter search metadata.
2498
2499        Parameters
2500        ----------
2501        search_names : list
2502            a list of strings of search names
2503        value_filter : str
2504            a string of filters to apply; of this general structure:
2505            'group_name1/par_name1:(<>=)value1,group_name2/par_name2:(<>=)value2', e.g.
2506            'data/overlap:=50,results/recall:>0.5,data/feature_extraction:=kinematic'
2507        display_parameters : list
2508            list of parameters to display (e.g. ['data/overlap', 'results/recall'])
2509        print_results : bool, default True
2510            if True, the result will be printed to standard output
2511
2512        Returns
2513        -------
2514        pd.DataFrame
2515            the filtered dataframe
2516
2517        """
2518        searches = self._searches().list_episodes(
2519            search_names, value_filter, display_parameters
2520        )
2521        if print_results:
2522            print("SEARCHES")
2523            print(searches)
2524            print("\n")
2525        return searches

Get a filtered pandas dataframe with hyperparameter search metadata.

Parameters

search_names : list a list of strings of search names value_filter : str a string of filters to apply; of this general structure: 'group_name1/par_name1:(<>=)value1,group_name2/par_name2:(<>=)value2', e.g. 'data/overlap:=50,results/recall:>0.5,data/feature_extraction:=kinematic' display_parameters : list list of parameters to display (e.g. ['data/overlap', 'results/recall']) print_results : bool, default True if True, the result will be printed to standard output

Returns

pd.DataFrame the filtered dataframe

def get_best_parameters(self, search_name: str, round_to_binary: List = None):
2527    def get_best_parameters(
2528        self,
2529        search_name: str,
2530        round_to_binary: List = None,
2531    ):
2532        """Get the best parameters found by a search.
2533
2534        Parameters
2535        ----------
2536        search_name : str
2537            the name of the search
2538        round_to_binary : list, default None
2539            a list of parameters to round to binary values
2540
2541        Returns
2542        -------
2543        best_params : dict
2544            a dictionary of the best parameters where the keys are in '{group}/{name}' format
2545
2546        """
2547        params, model = self._searches().get_best_params(
2548            search_name, round_to_binary=round_to_binary
2549        )
2550        params = self._update(params, {"general": {"model_name": model}})
2551        return params

Get the best parameters found by a search.

Parameters

search_name : str the name of the search round_to_binary : list, default None a list of parameters to round to binary values

Returns

best_params : dict a dictionary of the best parameters where the keys are in '{group}/{name}' format

def list_best_parameters(self, search_name: str, print_results: bool = True) -> Dict:
2553    def list_best_parameters(
2554        self, search_name: str, print_results: bool = True
2555    ) -> Dict:
2556        """Get the raw dictionary of best parameters found by a search.
2557
2558        Parameters
2559        ----------
2560        search_name : str
2561            the name of the search
2562        print_results : bool, default True
2563            if True, the result will be printed to standard output
2564
2565        Returns
2566        -------
2567        best_params : dict
2568            a dictionary of the best parameters where the keys are in '{group}/{name}' format
2569
2570        """
2571        params = self._searches().get_best_params_raw(search_name)
2572        if print_results:
2573            print(f"SEARCH RESULTS {search_name}")
2574            for k, v in params.items():
2575                print(f"{k}: {v}")
2576            print("\n")
2577        return params

Get the raw dictionary of best parameters found by a search.

Parameters

search_name : str the name of the search print_results : bool, default True if True, the result will be printed to standard output

Returns

best_params : dict a dictionary of the best parameters where the keys are in '{group}/{name}' format

def plot_episodes( self, episode_names: List, metrics: Union[List, str], modes: Union[List, str] = None, title: str = None, episode_labels: List = None, save_path: str = None, add_hlines: List = None, epoch_limits: List = None, colors: List = None, add_highpoint_hlines: bool = False, remove_box: bool = False, font_size: float = None, linewidth: float = None, return_ax: bool = False) -> None:
2579    def plot_episodes(
2580        self,
2581        episode_names: List,
2582        metrics: List | str,
2583        modes: List | str = None,
2584        title: str = None,
2585        episode_labels: List = None,
2586        save_path: str = None,
2587        add_hlines: List = None,
2588        epoch_limits: List = None,
2589        colors: List = None,
2590        add_highpoint_hlines: bool = False,
2591        remove_box: bool = False,
2592        font_size: float = None,
2593        linewidth: float = None,
2594        return_ax: bool = False,
2595    ) -> None:
2596        """Plot episode training curves.
2597
2598        Parameters
2599        ----------
2600        episode_names : list
2601            a list of episode names to plot; to plot to episodes in one line combine them in a list
2602            (e.g. ['episode1', ['episode2', 'episode3']] to plot episode2 and episode3 as one experiment)
2603        metrics : list
2604            a list of metric to plot
2605        modes : list, optional
2606            a list of modes to plot ('train' and/or 'val'; `['val']` by default)
2607        title : str, optional
2608            title for the plot
2609        episode_labels : list, optional
2610            a list of strings used to label the curves (has to be the same length as episode_names)
2611        save_path : str, optional
2612            the path to save the resulting plot
2613        add_hlines : list, optional
2614            a list of float values (or (value, label) tuples) to mark with horizontal lines
2615        epoch_limits : list, optional
2616            a list of (min, max) tuples to set the x-axis limits for each episode
2617        colors: list, optional
2618            a list of matplotlib colors
2619        add_highpoint_hlines : bool, default False
2620            if `True`, horizontal lines will be added at the highest value of each episode
2621        """
2622
2623        if isinstance(metrics, str):
2624            metrics = [metrics]
2625        if isinstance(modes, str):
2626            modes = [modes]
2627
2628        if font_size is not None:
2629            font = {"size": font_size}
2630            rc("font", **font)
2631        if modes is None:
2632            modes = ["val"]
2633        if add_hlines is None:
2634            add_hlines = []
2635        logs = []
2636        epochs = []
2637        labels = []
2638        if episode_labels is not None:
2639            assert len(episode_labels) == len(episode_names)
2640        for name_i, name in enumerate(episode_names):
2641            log_params = product(metrics, modes)
2642            for metric, mode in log_params:
2643                if episode_labels is not None:
2644                    label = episode_labels[name_i]
2645                else:
2646                    label = deepcopy(name)
2647                if len(modes) != 1:
2648                    label += f"_{mode}"
2649                if len(metrics) != 1:
2650                    label += f"_{metric}"
2651                labels.append(label)
2652                if isinstance(name, Iterable) and not isinstance(name, str):
2653                    epoch_list = defaultdict(lambda: [])
2654                    multi_logs = defaultdict(lambda: [])
2655                    for i, n in enumerate(name):
2656                        runs = self._episodes().get_runs(n)
2657                        if len(runs) > 1:
2658                            for run in runs:
2659                                if "::" in run:
2660                                    index = run.split("::")[-1]
2661                                else:
2662                                    index = run.split("#")[-1]
2663                                if multi_logs[index] == []:
2664                                    if multi_logs["null"] is None:
2665                                        raise RuntimeError(
2666                                            "The run indices are not consistent across episodes!"
2667                                        )
2668                                    else:
2669                                        multi_logs[index] += multi_logs["null"]
2670                                multi_logs[index] += list(
2671                                    self._episode(run).get_metric_log(mode, metric)
2672                                )
2673                                start = (
2674                                    0
2675                                    if len(epoch_list[index]) == 0
2676                                    else epoch_list[index][-1]
2677                                )
2678                                epoch_list[index] += [
2679                                    x + start
2680                                    for x in self._episode(run).get_epoch_list(mode)
2681                                ]
2682                            multi_logs["null"] = None
2683                        else:
2684                            if len(multi_logs.keys()) > 1:
2685                                raise RuntimeError(
2686                                    "Cannot plot a single-run episode after a multi-run episode!"
2687                                )
2688                            multi_logs["null"] += list(
2689                                self._episode(n).get_metric_log(mode, metric)
2690                            )
2691                            start = (
2692                                0
2693                                if len(epoch_list["null"]) == 0
2694                                else epoch_list["null"][-1]
2695                            )
2696                            epoch_list["null"] += [
2697                                x + start for x in self._episode(n).get_epoch_list(mode)
2698                            ]
2699                    if len(multi_logs.keys()) == 1:
2700                        log = multi_logs["null"]
2701                        epochs.append(epoch_list["null"])
2702                    else:
2703                        log = tuple([v for k, v in multi_logs.items() if k != "null"])
2704                        epochs.append(
2705                            tuple([v for k, v in epoch_list.items() if k != "null"])
2706                        )
2707                else:
2708                    runs = self._episodes().get_runs(name)
2709                    if len(runs) > 1:
2710                        log = []
2711                        for run in runs:
2712                            tracked_metrics = self._episode(run).get_metrics()
2713                            if metric in tracked_metrics:
2714                                log.append(
2715                                    list(
2716                                        self._episode(run).get_metric_log(mode, metric)
2717                                    )
2718                                )
2719                            else:
2720                                relevant = []
2721                                for m in tracked_metrics:
2722                                    m_split = m.split("_")
2723                                    if (
2724                                        "_".join(m_split[:-1]) == metric
2725                                        and m_split[-1].isnumeric()
2726                                    ):
2727                                        relevant.append(m)
2728                                if len(relevant) == 0:
2729                                    raise ValueError(
2730                                        f"The {metric} metric was not tracked at {run}"
2731                                    )
2732                                arr = 0
2733                                for m in relevant:
2734                                    arr += self._episode(run).get_metric_log(mode, m)
2735                                arr /= len(relevant)
2736                                log.append(list(arr))
2737                        log = tuple(log)
2738                        epochs.append(
2739                            tuple(
2740                                [
2741                                    self._episode(run).get_epoch_list(mode)
2742                                    for run in runs
2743                                ]
2744                            )
2745                        )
2746                    else:
2747                        tracked_metrics = self._episode(name).get_metrics()
2748                        if metric in tracked_metrics:
2749                            log = list(self._episode(name).get_metric_log(mode, metric))
2750                        else:
2751                            relevant = []
2752                            for m in tracked_metrics:
2753                                m_split = m.split("_")
2754                                if (
2755                                    "_".join(m_split[:-1]) == metric
2756                                    and m_split[-1].isnumeric()
2757                                ):
2758                                    relevant.append(m)
2759                            if len(relevant) == 0:
2760                                raise ValueError(
2761                                    f"The {metric} metric was not tracked at {name}"
2762                                )
2763                            arr = 0
2764                            for m in relevant:
2765                                arr += self._episode(name).get_metric_log(mode, m)
2766                            arr /= len(relevant)
2767                            log = list(arr)
2768                        epochs.append(self._episode(name).get_epoch_list(mode))
2769                logs.append(log)
2770        # if episode_labels is not None:
2771        #     print(f'{len(episode_labels)=}, {len(logs)=}')
2772        #     if len(episode_labels) != len(logs):
2773
2774        #         raise ValueError(
2775        #             f"The length of episode_labels ({len(episode_labels)}) has to be equal to the length of "
2776        #             f"curves ({len(logs)})!"
2777        #         )
2778        #     else:
2779        #         labels = episode_labels
2780        if colors is None:
2781            colors = cm.rainbow(np.linspace(0, 1, len(logs)))
2782        if len(colors) != len(logs):
2783            raise ValueError(
2784                "The length of colors has to be equal to the length of curves (metrics * modes * episode_names)!"
2785            )
2786        f, ax = plt.subplots()
2787        length = 0
2788        for log, label, color, epoch_list in zip(logs, labels, colors, epochs):
2789            if type(log) is list:
2790                if len(log) > length:
2791                    length = len(log)
2792                ax.plot(
2793                    epoch_list,
2794                    log,
2795                    label=label,
2796                    color=color,
2797                )
2798                if add_highpoint_hlines:
2799                    ax.axhline(np.max(log), linestyle="dashed", color=color)
2800            else:
2801                for l, xx in zip(log, epoch_list):
2802                    if len(l) > length:
2803                        length = len(l)
2804                    ax.plot(
2805                        xx,
2806                        l,
2807                        color=color,
2808                        alpha=0.2,
2809                    )
2810                if not all([len(x) == len(log[0]) for x in log]):
2811                    warnings.warn(
2812                        f"Got logs with unequal lengths in parallel runs for {label}"
2813                    )
2814                    log = list(log)
2815                    epoch_list = list(epoch_list)
2816                    for i, x in enumerate(epoch_list):
2817                        to_remove = []
2818                        for j, y in enumerate(x[1:]):
2819                            if y <= x[j - 1]:
2820                                y_ind = x.index(y)
2821                                to_remove += list(range(y_ind, j))
2822                        epoch_list[i] = [
2823                            y for j, y in enumerate(x) if j not in to_remove
2824                        ]
2825                        log[i] = [y for j, y in enumerate(log[i]) if j not in to_remove]
2826                    length = min([len(x) for x in log])
2827                    for i in range(len(log)):
2828                        log[i] = log[i][:length]
2829                        epoch_list[i] = epoch_list[i][:length]
2830                    if not all([x == epoch_list[0] for x in epoch_list]):
2831                        raise RuntimeError(
2832                            f"Got different epoch indices in parallel runs for {label}"
2833                        )
2834                mean = np.array(log).mean(0)
2835                ax.plot(
2836                    epoch_list[0],
2837                    mean,
2838                    label=label,
2839                    color=color,
2840                    linewidth=linewidth,
2841                )
2842                if add_highpoint_hlines:
2843                    ax.axhline(np.max(mean), linestyle="dashed", color=color)
2844        for x in add_hlines:
2845            label = None
2846            if isinstance(x, Iterable):
2847                x, label = x
2848            ax.axhline(x, label=label)
2849            ax.set_xlim((0, length))
2850
2851        ax.legend()
2852        ax.set_xlabel("epochs")
2853        if len(metrics) == 1:
2854            ax.set_ylabel(metrics[0])
2855        else:
2856            ax.set_ylabel("value")
2857        if title is None:
2858            if len(episode_names) == 1:
2859                title = episode_names[0]
2860            elif len(metrics) == 1:
2861                title = metrics[0]
2862        if epoch_limits is not None:
2863            ax.set_xlim(epoch_limits)
2864        if title is not None:
2865            ax.set_title(title)
2866        if remove_box:
2867            ax.box(False)
2868        if return_ax:
2869            return ax
2870        if save_path is not None:
2871            plt.savefig(save_path)
2872        plt.show()

Plot episode training curves.

Parameters

episode_names : list a list of episode names to plot; to plot two episodes as one line combine them in a list (e.g. ['episode1', ['episode2', 'episode3']] to plot episode2 and episode3 as one experiment) metrics : list a list of metrics to plot modes : list, optional a list of modes to plot ('train' and/or 'val'; ['val'] by default) title : str, optional title for the plot episode_labels : list, optional a list of strings used to label the curves (has to be the same length as episode_names) save_path : str, optional the path to save the resulting plot add_hlines : list, optional a list of float values (or (value, label) tuples) to mark with horizontal lines epoch_limits : list, optional a list of (min, max) tuples to set the x-axis limits for each episode colors: list, optional a list of matplotlib colors add_highpoint_hlines : bool, default False if True, horizontal lines will be added at the highest value of each episode

def update_parameters( self, parameters_update: Dict = None, load_search: str = None, load_parameters: List = None, round_to_binary: List = None) -> None:
2874    def update_parameters(
2875        self,
2876        parameters_update: Dict = None,
2877        load_search: str = None,
2878        load_parameters: List = None,
2879        round_to_binary: List = None,
2880    ) -> None:
2881        """Update the parameters in the project config files.
2882
2883        Parameters
2884        ----------
2885        parameters_update : dict, optional
2886            a dictionary of parameter updates
2887        load_search : str, optional
2888            the name of hyperparameter search results to load to config
2889        load_parameters : list, optional
2890            a list of lists of string names of the parameters to load from the searches
2891        round_to_binary : list, optional
2892            a list of string names of the loaded parameters that should be rounded to the nearest power of two
2893
2894        """
2895        keys = [
2896            "general",
2897            "losses",
2898            "metrics",
2899            "ssl",
2900            "training",
2901            "data",
2902        ]
2903        parameters = self._read_parameters(catch_blanks=False)
2904        if parameters_update is not None:
2905            model_params = (
2906                parameters_update.pop("model") if "model" in parameters_update else None
2907            )
2908            feat_params = (
2909                parameters_update.pop("features")
2910                if "features" in parameters_update
2911                else None
2912            )
2913            aug_params = (
2914                parameters_update.pop("augmentations")
2915                if "augmentations" in parameters_update
2916                else None
2917            )
2918
2919            parameters = self._update(parameters, parameters_update)
2920            model_name = parameters["general"]["model_name"]
2921            parameters["model"] = self._open_yaml(
2922                os.path.join(self.project_path, "config", "model", f"{model_name}.yaml")
2923            )
2924            if model_params is not None:
2925                parameters["model"] = self._update(parameters["model"], model_params)
2926            feat_name = parameters["general"]["feature_extraction"]
2927            parameters["features"] = self._open_yaml(
2928                os.path.join(
2929                    self.project_path, "config", "features", f"{feat_name}.yaml"
2930                )
2931            )
2932            if feat_params is not None:
2933                parameters["features"] = self._update(
2934                    parameters["features"], feat_params
2935                )
2936            aug_name = options.extractor_to_transformer[
2937                parameters["general"]["feature_extraction"]
2938            ]
2939            parameters["augmentations"] = self._open_yaml(
2940                os.path.join(
2941                    self.project_path, "config", "augmentations", f"{aug_name}.yaml"
2942                )
2943            )
2944            if aug_params is not None:
2945                parameters["augmentations"] = self._update(
2946                    parameters["augmentations"], aug_params
2947                )
2948        if load_search is not None:
2949            parameters_update, model_name = self._searches().get_best_params(
2950                load_search, load_parameters, round_to_binary
2951            )
2952            parameters["general"]["model_name"] = model_name
2953            parameters["model"] = self._open_yaml(
2954                os.path.join(self.project_path, "config", "model", f"{model_name}.yaml")
2955            )
2956            parameters = self._update(parameters, parameters_update)
2957        for key in keys:
2958            with open(
2959                os.path.join(self.project_path, "config", f"{key}.yaml"),
2960                "w",
2961                encoding="utf-8",
2962            ) as f:
2963                YAML().dump(parameters[key], f)
2964        model_name = parameters["general"]["model_name"]
2965        model_path = os.path.join(
2966            self.project_path, "config", "model", f"{model_name}.yaml"
2967        )
2968        with open(model_path, "w", encoding="utf-8") as f:
2969            YAML().dump(parameters["model"], f)
2970        features_name = parameters["general"]["feature_extraction"]
2971        features_path = os.path.join(
2972            self.project_path, "config", "features", f"{features_name}.yaml"
2973        )
2974        with open(features_path, "w", encoding="utf-8") as f:
2975            YAML().dump(parameters["features"], f)
2976        aug_name = options.extractor_to_transformer[features_name]
2977        aug_path = os.path.join(
2978            self.project_path, "config", "augmentations", f"{aug_name}.yaml"
2979        )
2980        with open(aug_path, "w", encoding="utf-8") as f:
2981            YAML().dump(parameters["augmentations"], f)

Update the parameters in the project config files.

Parameters

parameters_update : dict, optional a dictionary of parameter updates load_search : str, optional the name of hyperparameter search results to load to config load_parameters : list, optional a list of lists of string names of the parameters to load from the searches round_to_binary : list, optional a list of string names of the loaded parameters that should be rounded to the nearest power of two

def get_summary( self, episode_names: list, method: str = 'last', average: int = 1, metrics: List = None, return_values: bool = False) -> Dict:
2983    def get_summary(
2984        self,
2985        episode_names: list,
2986        method: str = "last",
2987        average: int = 1,
2988        metrics: List = None,
2989        return_values: bool = False,
2990    ) -> Dict:
2991        """Get a summary of episode statistics.
2992
2993        If an episode has multiple runs, the statistics will be aggregated over all of them.
2994
2995        Parameters
2996        ----------
2997        episode_names : str
2998            the names of the episodes
2999        method : ["best", "last"]
3000            the method for choosing the epochs
3001        average : int, default 1
3002            the number of epochs to average over (for each run)
3003        metrics : list, optional
3004            a list of metrics
3005
3006        Returns
3007        -------
3008        statistics : dict
3009            a nested dictionary where first-level keys are metric names and second-level keys are 'mean' for the mean
3010            and 'std' for the standard deviation
3011
3012        """
3013        runs = []
3014        for episode_name in episode_names:
3015            runs_ep = self._episodes().get_runs(episode_name)
3016            if len(runs_ep) == 0:
3017                raise RuntimeError(
3018                    f"There is no {episode_name} episode in the project memory"
3019                )
3020            runs += runs_ep
3021        if metrics is None:
3022            metrics = self._episode(runs[0]).get_metrics()
3023
3024        values = {m: [] for m in metrics}
3025        for run in runs:
3026            for m in metrics:
3027                log = self._episode(run).get_metric_log(mode="val", metric_name=m)
3028                if method == "best":
3029                    log = sorted(log)
3030                    values[m] += list(log[-average:])
3031                elif method == "last":
3032                    if len(log) == 0:
3033                        episodes = self._episodes().data
3034                        if average == 1 and ("results", m) in episodes.columns:
3035                            values[m] += [episodes.loc[run, ("results", m)]]
3036                        else:
3037                            raise RuntimeError(f"Did not find {m} metric for {run} run")
3038                    values[m] += list(log[-average:])
3039                elif method.startswith("epoch"):
3040                    epoch = int(method[5:]) - 1
3041                    pars = self._episodes().load_parameters(run)
3042                    step = int(pars["training"]["validation_interval"])
3043                    values[m] += [log[epoch // step]]
3044                else:
3045                    raise ValueError(
3046                        f"The {method} method is not recognized! Please choose from ['last', 'best', 'epoch...']"
3047                    )
3048        statistics = defaultdict(lambda: {})
3049        for m, v in values.items():
3050            statistics[m]["mean"] = np.mean(v)
3051            statistics[m]["std"] = np.std(v)
3052        print(f"SUMMARY {episode_names}")
3053        for m, v in statistics.items():
3054            print(f'{m}: mean {v["mean"]:.3f}, std {v["std"]:.3f}')
3055        print("\n")
3056
3057        return (dict(statistics), values) if return_values else dict(statistics)

Get a summary of episode statistics.

If an episode has multiple runs, the statistics will be aggregated over all of them.

Parameters

episode_names : list the names of the episodes method : ["best", "last", "epoch<N>"] the method for choosing the epochs average : int, default 1 the number of epochs to average over (for each run) metrics : list, optional a list of metrics

Returns

statistics : dict a nested dictionary where first-level keys are metric names and second-level keys are 'mean' for the mean and 'std' for the standard deviation

@staticmethod
def remove_project(name: str, projects_path: str = None) -> None:
3059    @staticmethod
3060    def remove_project(name: str, projects_path: str = None) -> None:
3061        """Remove all project files and experiment records and results.
3062
3063        Parameters
3064        ----------
3065        name : str
3066            the name of the project to remove
3067        projects_path : str, optional
3068            the path to the projects directory (by default the home DLC2Action directory)
3069
3070        """
3071        if projects_path is None:
3072            projects_path = os.path.join(str(Path.home()), "DLC2Action")
3073        project_path = os.path.join(projects_path, name)
3074        if os.path.exists(project_path):
3075            shutil.rmtree(project_path)

Remove all project files and experiment records and results.

Parameters

name : str the name of the project to remove projects_path : str, optional the path to the projects directory (by default the home DLC2Action directory)

def remove_saved_features( self, dataset_names: List = None, exceptions: List = None, remove_active: bool = False) -> None:
3077    def remove_saved_features(
3078        self,
3079        dataset_names: List = None,
3080        exceptions: List = None,
3081        remove_active: bool = False,
3082    ) -> None:
3083        """Remove saved pre-computed dataset feature files.
3084
3085        By default, all features will be deleted.
3086        No essential information can get lost, storing them only saves time. Be careful with deleting datasets
3087        while training or inference is happening though.
3088
3089        Parameters
3090        ----------
3091        dataset_names : list, optional
3092            a list of dataset names to delete (by default all names are added)
3093        exceptions : list, optional
3094            a list of dataset names to not be deleted
3095        remove_active : bool, default False
3096            if `False`, datasets used by unfinished episodes will not be deleted
3097
3098        """
3099        print("Removing datasets...")
3100        if dataset_names is None:
3101            dataset_names = []
3102        if exceptions is None:
3103            exceptions = []
3104        if not remove_active:
3105            exceptions += self._episodes().get_active_datasets()
3106        dataset_path = os.path.join(self.project_path, "saved_datasets")
3107        if os.path.exists(dataset_path):
3108            if dataset_names == []:
3109                dataset_names = set([f.split(".")[0] for f in os.listdir(dataset_path)])
3110
3111            to_remove = [
3112                x
3113                for x in dataset_names
3114                if os.path.exists(os.path.join(dataset_path, x)) and x not in exceptions
3115            ]
3116            if len(to_remove) > 2:
3117                to_remove = tqdm(to_remove)
3118            for dataset in to_remove:
3119                shutil.rmtree(os.path.join(dataset_path, dataset))
3120            to_remove = [
3121                f"{x}.pickle"
3122                for x in dataset_names
3123                if os.path.exists(os.path.join(dataset_path, f"{x}.pickle"))
3124                and x not in exceptions
3125            ]
3126            for dataset in to_remove:
3127                os.remove(os.path.join(dataset_path, dataset))
3128            names = self._saved_datasets().dataset_names()
3129            self._saved_datasets().remove(names)
3130        print("\n")

Remove saved pre-computed dataset feature files.

By default, all features will be deleted. No essential information can get lost, storing them only saves time. Be careful with deleting datasets while training or inference is happening though.

Parameters

dataset_names : list, optional a list of dataset names to delete (by default all names are added) exceptions : list, optional a list of dataset names to not be deleted remove_active : bool, default False if False, datasets used by unfinished episodes will not be deleted

def remove_extra_checkpoints(self, episode_names: List = None, exceptions: List = None) -> None:
3132    def remove_extra_checkpoints(
3133        self, episode_names: List = None, exceptions: List = None
3134    ) -> None:
3135        """Remove intermediate model checkpoint files (only leave the files for the last epoch).
3136
3137        By default, all intermediate checkpoints will be deleted.
3138        Files in the model folder that are not associated with any record in the meta files are also deleted.
3139
3140        Parameters
3141        ----------
3142        episode_names : list, optional
3143            a list of episode names to clean (by default all names are added)
3144        exceptions : list, optional
3145            a list of episode names to not clean
3146
3147        """
3148        model_path = os.path.join(self.project_path, "results", "model")
3149        try:
3150            all_names = self._episodes().data.index
3151        except:
3152            all_names = os.listdir(model_path)
3153        if episode_names is None:
3154            episode_names = all_names
3155        if exceptions is None:
3156            exceptions = []
3157        to_remove = [x for x in episode_names if x not in exceptions]
3158        folders = os.listdir(model_path)
3159        for folder in folders:
3160            if folder not in all_names:
3161                shutil.rmtree(os.path.join(model_path, folder))
3162            elif folder in to_remove:
3163                files = os.listdir(os.path.join(model_path, folder))
3164                for file in sorted(files)[:-1]:
3165                    os.remove(os.path.join(model_path, folder, file))

Remove intermediate model checkpoint files (only leave the files for the last epoch).

By default, all intermediate checkpoints will be deleted. Files in the model folder that are not associated with any record in the meta files are also deleted.

Parameters

episode_names : list, optional a list of episode names to clean (by default all names are added) exceptions : list, optional a list of episode names to not clean

def remove_suggestion(self, suggestion_name: str) -> None:
3181    def remove_suggestion(self, suggestion_name: str) -> None:
3182        """Remove a suggestion record.
3183
3184        Parameters
3185        ----------
3186        suggestion_name : str
3187            the name of the suggestion to remove
3188
3189        """
3190        self._suggestions().remove_episode(suggestion_name)
3191        suggestion_path = os.path.join(
3192            self.project_path, "results", "suggestions", suggestion_name
3193        )
3194        if os.path.exists(suggestion_path):
3195            shutil.rmtree(suggestion_path)

Remove a suggestion record.

Parameters

suggestion_name : str the name of the suggestion to remove

def remove_prediction(self, prediction_name: str) -> None:
3197    def remove_prediction(self, prediction_name: str) -> None:
3198        """Remove a prediction record.
3199
3200        Parameters
3201        ----------
3202        prediction_name : str
3203            the name of the prediction to remove
3204
3205        """
3206        self._predictions().remove_episode(prediction_name)
3207        prediction_path = self.prediction_path(prediction_name)
3208        if os.path.exists(prediction_path):
3209            shutil.rmtree(prediction_path)

Remove a prediction record.

Parameters

prediction_name : str the name of the prediction to remove

def check_prediction_exists(self, prediction_name: str) -> str | None:
3211    def check_prediction_exists(self, prediction_name: str) -> str | None:
3212        """Check if a prediction exists.
3213
3214        Parameters
3215        ----------
3216        prediction_name : str
3217            the name of the prediction to check
3218
3219        Returns
3220        -------
3221        str | None
3222            the path to the prediction if it exists, `None` otherwise
3223
3224        """
3225        prediction_path = self.prediction_path(prediction_name)
3226        if os.path.exists(prediction_path):
3227            return prediction_path
3228        return None

Check if a prediction exists.

Parameters

prediction_name : str the name of the prediction to check

Returns

str | None the path to the prediction if it exists, None otherwise

def remove_episode(self, episode_name: str) -> None:
3230    def remove_episode(self, episode_name: str) -> None:
3231        """Remove all model, logs and metafile records related to an episode.
3232
3233        Parameters
3234        ----------
3235        episode_name : str
3236            the name of the episode to remove
3237
3238        """
3239        runs = self._episodes().get_runs(episode_name)
3240        runs.append(episode_name)
3241        for run in runs:
3242            self._episodes().remove_episode(run)
3243            model_path = os.path.join(self.project_path, "results", "model", run)
3244            if os.path.exists(model_path):
3245                shutil.rmtree(model_path)
3246            log_path = os.path.join(self.project_path, "results", "logs", f"{run}.txt")
3247            if os.path.exists(log_path):
3248                os.remove(log_path)

Remove all model, logs and metafile records related to an episode.

Parameters

episode_name : str the name of the episode to remove

def prune_unfinished(self, exceptions: List = None) -> List:
3273    def prune_unfinished(self, exceptions: List = None) -> List:
3274        """Remove all interrupted episodes.
3275
3276        Remove all episodes that either don't have a log file or have less epochs in the log file than in
3277        the training parameters or have a model folder but not a record. Note that it can remove episodes that are
3278        currently running!
3279
3280        Parameters
3281        ----------
3282        exceptions : list
3283            the episodes to keep even if they are interrupted
3284
3285        Returns
3286        -------
3287        pruned : list
3288            a list of the episode names that were pruned
3289
3290        """
3291        if exceptions is None:
3292            exceptions = []
3293        unfinished = self._episodes().unfinished_episodes()
3294        unfinished = [x for x in unfinished if x not in exceptions]
3295        model_folders = os.listdir(os.path.join(self.project_path, "results", "model"))
3296        unfinished += [
3297            x for x in model_folders if x not in self._episodes().list_episodes().index
3298        ]
3299        print(f"PRUNING {unfinished}")
3300        for episode_name in unfinished:
3301            self.remove_episode(episode_name)
3302        print(f"\n")
3303        return unfinished

Remove all interrupted episodes.

Remove all episodes that either don't have a log file or have less epochs in the log file than in the training parameters or have a model folder but not a record. Note that it can remove episodes that are currently running!

Parameters

exceptions : list the episodes to keep even if they are interrupted

Returns

pruned : list a list of the episode names that were pruned

def prediction_path(self, prediction_name: str) -> str:
3305    def prediction_path(self, prediction_name: str) -> str:
3306        """Get the path where prediction files are saved.
3307
3308        Parameters
3309        ----------
3310        prediction_name : str
3311            name of the prediction
3312
3313        Returns
3314        -------
3315        prediction_path : str
3316            the file path
3317
3318        """
3319        return os.path.join(
3320            self.project_path, "results", "predictions", f"{prediction_name}"
3321        )

Get the path where prediction files are saved.

Parameters

prediction_name : str name of the prediction

Returns

prediction_path : str the file path

def suggestion_path(self, suggestion_name: str) -> str:
3323    def suggestion_path(self, suggestion_name: str) -> str:
3324        """Get the path where suggestion files are saved.
3325
3326        Parameters
3327        ----------
3328        suggestion_name : str
3329            name of the prediction
3330
3331        Returns
3332        -------
3333        suggestion_path : str
3334            the file path
3335
3336        """
3337        return os.path.join(
3338            self.project_path, "results", "suggestions", f"{suggestion_name}"
3339        )

Get the path where suggestion files are saved.

Parameters

suggestion_name : str name of the suggestion

Returns

suggestion_path : str the file path

@classmethod
def print_data_types(cls):
3341    @classmethod
3342    def print_data_types(cls):
3343        """Print available data types."""
3344        print("DATA TYPES:")
3345        for key, value in cls.data_types().items():
3346            print(f"{key}:")
3347            print(value.__doc__)

Print available data types.

@classmethod
def print_annotation_types(cls):
3349    @classmethod
3350    def print_annotation_types(cls):
3351        """Print available annotation types."""
3352        print("ANNOTATION TYPES:")
3353        for key, value in cls.annotation_types():
3354            print(f"{key}:")
3355            print(value.__doc__)

Print available annotation types.

@staticmethod
def data_types() -> List:
    @staticmethod
    def data_types() -> Dict:
        """Get available data types.

        Returns
        -------
        data_types : dict
            the `options.input_stores` registry; it is used as a mapping elsewhere in this class
            (`.items()`, lookup by data type name), with values whose `__doc__` is printed as help
            (presumably the input store classes -- confirm in `dlc2action.options`)

        """
        return options.input_stores

Get available data types.

Returns

data_types : list available data types

@staticmethod
def annotation_types() -> List:
    @staticmethod
    def annotation_types() -> Dict:
        """Get available annotation types.

        Returns
        -------
        dict
            the `options.annotation_stores` registry; it is looked up by annotation type name
            elsewhere in this class, with values whose `__doc__` is printed as help
            (presumably the annotation store classes -- confirm in `dlc2action.options`)

        """
        return options.annotation_stores

Get available annotation types.

Returns

list available annotation types

def set_main_parameters(self, model_name: str = None, metric_names: List = None):
3958    def set_main_parameters(self, model_name: str = None, metric_names: List = None):
3959        """Select the model and the metrics.
3960
3961        Parameters
3962        ----------
3963        model_name : str, optional
3964            model name; run `project.help("model") to find out more
3965        metric_names : list, optional
3966            a list of metric function names; run `project.help("metrics") to find out more
3967
3968        """
3969        pars = {"general": {}}
3970        if model_name is not None:
3971            assert model_name in options.models
3972            pars["general"]["model_name"] = model_name
3973        if metric_names is not None:
3974            for metric in metric_names:
3975                assert metric in options.metrics
3976            pars["general"]["metric_functions"] = metric_names
3977        self.update_parameters(pars)

Select the model and the metrics.

Parameters

model_name : str, optional model name; run `project.help("model")` to find out more metric_names : list, optional a list of metric function names; run `project.help("metrics")` to find out more

def help(self, keyword: str = None):
3979    def help(self, keyword: str = None):
3980        """Get information on available options.
3981
3982        Parameters
3983        ----------
3984        keyword : str, optional
3985            the keyword for options (run without arguments to see which keywords are available)
3986
3987        """
3988        if keyword is None:
3989            print("AVAILABLE HELP FUNCTIONS:")
3990            print("- Try running `project.help(keyword)` with the following keywords:")
3991            print("    - model: to get more information on available models,")
3992            print(
3993                "    - features: to get more information on available feature extraction modes,"
3994            )
3995            print(
3996                "    - partition_method: to get more information on available train/test/val partitioning methods,"
3997            )
3998            print("    - metrics: to see a list of available metric functions.")
3999            print("    - data: to see help for expected data structure")
4000            print(
4001                "- To start working with this project, first run `project.list_blanks()` to check which parameters need to be filled in."
4002            )
4003            print(
4004                "- After a model and metrics are set, run `project.list_basic_parameters()` to see a list of the most important parameters that you might want to modify"
4005            )
4006            print(
4007                f"- If you want to dig deeper, get the full dictionary with project._read_parameters() (it is a `ruamel.yaml.comments.CommentedMap` instance)."
4008            )
4009        elif keyword == "model":
4010            print("MODELS:")
4011            for key, model in options.models.items():
4012                print(f"{key}:")
4013                print(model.__doc__)
4014        elif keyword == "features":
4015            print("FEATURE EXTRACTORS:")
4016            for key, extractor in options.feature_extractors.items():
4017                print(f"{key}:")
4018                print(extractor.__doc__)
4019        elif keyword == "partition_method":
4020            print("PARTITION METHODS:")
4021            print(
4022                BehaviorDataset.partition_train_test_val.__doc__.split(
4023                    "The partitioning method:"
4024                )[1].split("val_frac :")[0]
4025            )
4026        elif keyword == "metrics":
4027            print("METRICS:")
4028            for key, metric in options.metrics.items():
4029                print(f"{key}:")
4030                print(metric.__doc__)
4031        elif keyword == "data":
4032            print("DATA:")
4033            print(f"Video data: {self.data_type}")
4034            print(options.input_stores[self.data_type].__doc__)
4035            print(f"Annotation data: {self.annotation_type}")
4036            print(options.annotation_stores[self.annotation_type].__doc__)
4037            print(
4038                "Annotation path and data path don't have to be separate, you can keep everything in one folder."
4039            )
4040        else:
4041            raise ValueError(f"The {keyword} keyword is not recognized")
4042        print("\n")

Get information on available options.

Parameters

keyword : str, optional the keyword for options (run without arguments to see which keywords are available)

def list_blanks(self, blanks=None):
4083    def list_blanks(self, blanks=None):
4084        """List parameters that need to be filled in.
4085
4086        Parameters
4087        ----------
4088        blanks : list, optional
4089            a list of the parameters to list, if already known
4090
4091        """
4092        if blanks is None:
4093            blanks = self._get_blanks()
4094        if len(blanks) > 0:
4095            to_update = defaultdict(lambda: [])
4096            for b, k, c in blanks:
4097                to_update[b].append((k, c))
4098            print("Before running experiments, please update all the blanks.")
4099            print("To do that, you can run this.")
4100            print("--------------------------------------------------------")
4101            print(f"project.update_parameters(")
4102            print(f"    {{")
4103            for big_key, keys in to_update.items():
4104                print(f'        "{big_key}": {{')
4105                for key, comment in keys:
4106                    print(f'            "{key}": ..., {comment}')
4107                print(f"        }}")
4108            print(f"    }}")
4109            print(")")
4110            print("--------------------------------------------------------")
4111            print("Replace ... with relevant values.")
4112        else:
4113            print("There is no blanks left!")

List parameters that need to be filled in.

Parameters

blanks : list, optional a list of the parameters to list, if already known

def list_basic_parameters(self):
4115    def list_basic_parameters(
4116        self,
4117    ):
4118        """Get a list of most relevant parameters and code to modify them."""
4119        parameters = self._read_parameters()
4120        print("BASIC PARAMETERS:")
4121        model_name = parameters["general"]["model_name"]
4122        metric_names = parameters["general"]["metric_functions"]
4123        loss_name = parameters["general"]["loss_function"]
4124        feature_extraction = parameters["general"]["feature_extraction"]
4125        print("Here is a list of current parameters.")
4126        print(
4127            "You can copy this code, change the parameters you want to set and run it to update the project config."
4128        )
4129        print("--------------------------------------------------------")
4130        print("project.update_parameters(")
4131        print("    {")
4132        for group in ["general", "data", "training"]:
4133            print(f'        "{group}": {{')
4134            for key in options.basic_parameters[group]:
4135                if key in parameters[group]:
4136                    print(
4137                        f'            "{key}": {self._process_value(parameters[group][key])}, {self._get_comment(parameters[group].ca.items.get(key))}'
4138                    )
4139            print("        },")
4140        print('        "losses": {')
4141        print(f'            "{loss_name}": {{')
4142        for key in options.basic_parameters["losses"][loss_name]:
4143            if key in parameters["losses"][loss_name]:
4144                print(
4145                    f'                "{key}": {self._process_value(parameters["losses"][loss_name][key])}, {self._get_comment(parameters["losses"][loss_name].ca.items.get(key))}'
4146                )
4147        print("            },")
4148        print("        },")
4149        print('        "metrics": {')
4150        for metric in metric_names:
4151            print(f'            "{metric}": {{')
4152            for key in parameters["metrics"][metric]:
4153                print(
4154                    f'                "{key}": {self._process_value(parameters["metrics"][metric][key])}, {self._get_comment(parameters["metrics"][metric].ca.items.get(key))}'
4155                )
4156            print("            },")
4157        print("        },")
4158        print('        "model": {')
4159        for key in options.basic_parameters["model"][model_name]:
4160            if key in parameters["model"]:
4161                print(
4162                    f'            "{key}": {self._process_value(parameters["model"][key])}, {self._get_comment(parameters["model"].ca.items.get(key))}'
4163                )
4164
4165        print("        },")
4166        print('        "features": {')
4167        for key in options.basic_parameters["features"][feature_extraction]:
4168            if key in parameters["features"]:
4169                print(
4170                    f'            "{key}": {self._process_value(parameters["features"][key])}, {self._get_comment(parameters["features"].ca.items.get(key))}'
4171                )
4172
4173        print("        },")
4174        print('        "augmentations": {')
4175        for key in options.basic_parameters["augmentations"][feature_extraction]:
4176            if key in parameters["augmentations"]:
4177                print(
4178                    f'            "{key}": {self._process_value(parameters["augmentations"][key])}, {self._get_comment(parameters["augmentations"].ca.items.get(key))}'
4179                )
4180        print("        },")
4181        print("    },")
4182        print(")")
4183        print("--------------------------------------------------------")
4184        print("\n")

Get a list of most relevant parameters and code to modify them.

def count_classes( self, load_episode: str = None, parameters_update: Dict = None, remove_saved_features: bool = False, bouts: bool = True) -> Dict:
5098    def count_classes(
5099        self,
5100        load_episode: str = None,
5101        parameters_update: Dict = None,
5102        remove_saved_features: bool = False,
5103        bouts: bool = True,
5104    ) -> Dict:
5105        """Get a dictionary of class counts in different modes.
5106
5107        Parameters
5108        ----------
5109        load_episode : str, optional
5110            the episode settings to load
5111        parameters_update : dict, optional
5112            a dictionary of parameter updates (only for "data" and "general" categories)
5113        remove_saved_features : bool, default False
5114            if `True`, the dataset that is used for computation is then deleted
5115        bouts : bool, default False
5116            if `True`, instead of frame counts segment counts are returned
5117
5118        Returns
5119        -------
5120        class_counts : dict
5121            a dictionary where first-level keys are "train", "val" and "test", second-level keys are
5122            class names and values are class counts (in frames)
5123
5124        """
5125        if load_episode is None:
5126            task, parameters = self._make_task_training(
5127                episode_name="_", parameters_update=parameters_update, throwaway=True
5128            )
5129        else:
5130            task, parameters, _ = self._make_task_prediction(
5131                "_",
5132                load_episode=load_episode,
5133                parameters_update=parameters_update,
5134            )
5135        class_counts = task.count_classes(bouts=bouts)
5136        behaviors = task.behaviors_dict()
5137        class_counts = {
5138            kk: {behaviors.get(k, "unknown"): v for k, v in vv.items()}
5139            for kk, vv in class_counts.items()
5140        }
5141        if remove_saved_features:
5142            self._remove_stores(parameters)
5143        return class_counts

Get a dictionary of class counts in different modes.

Parameters

load_episode : str, optional the episode settings to load parameters_update : dict, optional a dictionary of parameter updates (only for "data" and "general" categories) remove_saved_features : bool, default False if True, the dataset that is used for computation is then deleted bouts : bool, default True if True, instead of frame counts segment counts are returned

Returns

class_counts : dict a dictionary where first-level keys are "train", "val" and "test", second-level keys are class names and values are class counts (in frames)

def plot_class_distribution( self, parameters_update: Dict = None, frame_cutoff: int = 1, bout_cutoff: int = 1, print_full: bool = False, remove_saved_features: bool = False, save: str = None) -> None:
5145    def plot_class_distribution(
5146        self,
5147        parameters_update: Dict = None,
5148        frame_cutoff: int = 1,
5149        bout_cutoff: int = 1,
5150        print_full: bool = False,
5151        remove_saved_features: bool = False,
5152        save: str = None,
5153    ) -> None:
5154        """Make a class distribution plot.
5155
5156        You can either specify the parameters, choose an existing dataset or do neither (in that case a dataset
5157        is created or loaded for the computation with the default parameters).
5158
5159        Parameters
5160        ----------
5161        parameters_update : dict, optional
5162            a dictionary of parameter updates (only for "data" and "general" categories)
5163        frame_cutoff : int, default 1
5164            the minimum number of frames for a segment to be considered
5165        bout_cutoff : int, default 1
5166            the minimum number of bouts for a class to be considered
5167        print_full : bool, default False
5168            if `True`, the full class distribution is printed
5169        remove_saved_features : bool, default False
5170            if `True`, the dataset that is used for computation is then deleted
5171
5172        """
5173        task, parameters = self._make_task_training(
5174            episode_name="_", parameters_update=parameters_update, throwaway=True
5175        )
5176        cutoff = {True: bout_cutoff, False: frame_cutoff}
5177        for bouts in [True, False]:
5178            class_counts = task.count_classes(bouts=bouts)
5179            if print_full:
5180                print("Bouts:" if bouts else "Frames:")
5181                for k, v in class_counts.items():
5182                    if sum(v.values()) != 0:
5183                        print(f"  {k}:")
5184                        values, keys = zip(
5185                            *[
5186                                x
5187                                for x in sorted(zip(v.values(), v.keys()), reverse=True)
5188                                if x[-1] != -100
5189                            ]
5190                        )
5191                        for kk, vv in zip(keys, values):
5192                            print(f"    {task.behaviors_dict()[kk]}: {vv}")
5193            class_counts = {
5194                kk: {k: v for k, v in vv.items() if v >= cutoff[bouts]}
5195                for kk, vv in class_counts.items()
5196            }
5197            for key, d in class_counts.items():
5198                if sum(d.values()) != 0:
5199                    values, keys = zip(
5200                        *[x for x in sorted(zip(d.values(), d.keys())) if x[-1] != -100]
5201                    )
5202                    keys = [task.behaviors_dict()[x] for x in keys]
5203                    plt.bar(keys, values)
5204                    plt.title(key)
5205                    plt.xticks(rotation=45, ha="right")
5206                    if bouts:
5207                        plt.ylabel("bouts")
5208                    else:
5209                        plt.ylabel("frames")
5210                    plt.tight_layout()
5211
5212                    if save is None:
5213                        plt.savefig(save)
5214                        plt.close()
5215                    else:
5216                        plt.show()
5217        if remove_saved_features:
5218            self._remove_stores(parameters)

Make a class distribution plot.

You can either specify the parameters, choose an existing dataset or do neither (in that case a dataset is created or loaded for the computation with the default parameters).

Parameters

parameters_update : dict, optional a dictionary of parameter updates (only for "data" and "general" categories) frame_cutoff : int, default 1 the minimum number of frames for a segment to be considered bout_cutoff : int, default 1 the minimum number of bouts for a class to be considered print_full : bool, default False if True, the full class distribution is printed remove_saved_features : bool, default False if True, the dataset that is used for computation is then deleted

def plot_confusion_matrix( self, episode_name: str, load_epoch: int = None, parameters_update: Dict = None, metric: str = 'recall', mode: str = 'val', remove_saved_features: bool = False, save_path: str = None, cmap: str = 'viridis') -> Tuple[numpy.ndarray, Iterable]:
5711    def plot_confusion_matrix(
5712        self,
5713        episode_name: str,
5714        load_epoch: int = None,
5715        parameters_update: Dict = None,
5716        metric: str = "recall",
5717        mode: str = "val",
5718        remove_saved_features: bool = False,
5719        save_path: str = None,
5720        cmap: str = "viridis",
5721    ) -> Tuple[ndarray, Iterable]:
5722        """Make a confusion matrix plot and return the data.
5723
5724        If the annotation is non-exclusive, only false positive labels are considered.
5725
5726        Parameters
5727        ----------
5728        episode_name : str
5729            the name of the episode to load
5730        load_epoch : int, optional
5731            the index of the epoch to load (by default the last one is loaded)
5732        parameters_update : dict, optional
5733            a dictionary of parameter updates (only for "data" and "general" categories)
5734        metric : {"recall", "precision"}
5735            for datasets with non-exclusive annotation, if `type` is `"recall"`, only false positives are taken
5736            into account, and if `type` is `"precision"`, only false negatives
5737        mode : {'val', 'all', 'test', 'train'}
5738            the subset of the data to make the prediction for (forced to 'all' if data_path is not None)
5739        remove_saved_features : bool, default False
5740            if `True`, the dataset that is used for computation is then deleted
5741
5742        Returns
5743        -------
5744        confusion_matrix : np.ndarray
5745            a confusion matrix of shape `(#classes, #classes)` where `A[i, j] = F_ij/N_i`, `F_ij` is the number of
5746            frames that have the i-th label in the ground truth and a false positive j-th label in the prediction,
5747            `N_i` is the number of frames that have the i-th label in the ground truth
5748        classes : list
5749            a list of labels
5750
5751        """
5752        task, parameters, mode = self._make_task_prediction(
5753            "_",
5754            load_episode=episode_name,
5755            load_epoch=load_epoch,
5756            parameters_update=parameters_update,
5757            mode=mode,
5758        )
5759        dataset = task.dataset(mode)
5760        prediction = task.predict(dataset, raw_output=True)
5761        confusion_matrix, classes, type = dataset.get_confusion_matrix(prediction, type)
5762        if remove_saved_features:
5763            self._remove_stores(parameters)
5764        fig, ax = plt.subplots(figsize=(len(classes), len(classes)))
5765        ax.imshow(confusion_matrix, cmap=cmap)
5766        # Show all ticks and label them with the respective list entries
5767        ax.set_xticks(np.arange(len(classes)))
5768        ax.set_xticklabels(classes)
5769        ax.set_yticks(np.arange(len(classes)))
5770        ax.set_yticklabels(classes)
5771        # Rotate the tick labels and set their alignment.
5772        plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
5773        # Loop over data dimensions and create text annotations.
5774        for i in range(len(classes)):
5775            for j in range(len(classes)):
5776                ax.text(
5777                    j,
5778                    i,
5779                    np.round(confusion_matrix[i, j], 2),
5780                    ha="center",
5781                    va="center",
5782                    color="w",
5783                )
5784        if metric is not None:
5785            ax.set_title(f"{metric} {episode_name}")
5786        else:
5787            ax.set_title(episode_name)
5788        fig.tight_layout()
5789        if save_path is None:
5790            plt.show()
5791        else:
5792            plt.savefig(save_path)
5793            plt.close()
5794        return confusion_matrix, classes

Make a confusion matrix plot and return the data.

If the annotation is non-exclusive, only false positive labels are considered.

Parameters

episode_name : str the name of the episode to load load_epoch : int, optional the index of the epoch to load (by default the last one is loaded) parameters_update : dict, optional a dictionary of parameter updates (only for "data" and "general" categories) metric : {"recall", "precision"} for datasets with non-exclusive annotation, if metric is "recall", only false positives are taken into account, and if metric is "precision", only false negatives mode : {'val', 'all', 'test', 'train'} the subset of the data to make the prediction for (forced to 'all' if data_path is not None) remove_saved_features : bool, default False if True, the dataset that is used for computation is then deleted

Returns

confusion_matrix : np.ndarray a confusion matrix of shape (#classes, #classes) where A[i, j] = F_ij/N_i, F_ij is the number of frames that have the i-th label in the ground truth and a false positive j-th label in the prediction, N_i is the number of frames that have the i-th label in the ground truth classes : list a list of labels

def plot_ethograms( self, episode_name: str, prediction_name: str, start: int = 0, end: int = -1, save_path: str = None, cmap_pred: str = 'binary', cmap_gt: str = 'binary', fontsize: int = 22, time_mode: str = 'frames', fps: int = None):
5884    def plot_ethograms(
5885        self,
5886        episode_name: str,
5887        prediction_name: str,
5888        start: int = 0,
5889        end: int = -1,
5890        save_path: str = None,
5891        cmap_pred: str = "binary",
5892        cmap_gt: str = "binary",
5893        fontsize: int = 22,
5894        time_mode: str = "frames",
5895        fps: int = None,
5896    ):
5897        """Plot ethograms from start to end time (in frames) for ground truth and prediction"""
5898        params = self._read_parameters(catch_blanks=False)
5899        parameters = self._get_data_pars(
5900            params,
5901        )
5902        if not save_path is None:
5903            os.makedirs(save_path, exist_ok=True)
5904        gt_files = [
5905            f for f in self.data_path if f.endswith(parameters["annotation_suffix"])
5906        ]
5907        pred_path = os.path.join(
5908            self.project_path, "results", "predictions", prediction_name
5909        )
5910        pred_paths = [os.path.join(pred_path, f) for f in os.listdir(pred_path)]
5911        for pred_path in pred_paths:
5912            predictions = load_pickle(pred_path)
5913            behaviors = self.get_behavior_dictionary(episode_name)
5914            gt_filename = os.path.basename(pred_path).replace(
5915                "_".join(["_" + prediction_name, "prediction.pickle"]),
5916                parameters["annotation_suffix"],
5917            )
5918            if os.path.exists(os.path.join(self.data_path, gt_filename)):
5919                gt_data = load_pickle(os.path.join(self.data_path, gt_filename))
5920
5921                self._plot_ethograms_gt_pred(
5922                    gt_data,
5923                    predictions,
5924                    gt_data[1],
5925                    behaviors,
5926                    start=start,
5927                    end=end,
5928                    save=os.path.join(
5929                        save_path,
5930                        os.path.splitext(os.path.basename(pred_path))[0] + "_gt_pred",
5931                    ),
5932                    cmap_pred=cmap_pred,
5933                    cmap_gt=cmap_gt,
5934                    fontsize=fontsize,
5935                    time_mode=time_mode,
5936                    fps=fps,
5937                )
5938            else:
5939                print("GT file not found")

Plot ethograms from start to end time (in frames) for ground truth and prediction

def create_annotated_video( self, prediction_file_paths: list, video_file_paths: list, episode_name: str, ground_truth_file_paths: list = None, pred_thresh: float = 0.5, start: int = 0, end: int = -1):
6000    def create_annotated_video(
6001        self,
6002        prediction_file_paths: list,
6003        video_file_paths: list,
6004        episode_name: str,  # To get the list of behaviors
6005        ground_truth_file_paths: list = None,
6006        pred_thresh: float = 0.5,
6007        start: int = 0,
6008        end: int = -1,
6009    ):
6010        """Create a video with the predictions overlaid on the video"""
6011        for k, (pred_path, vid_path) in enumerate(
6012            zip(prediction_file_paths, video_file_paths)
6013        ):
6014            print("Generating video for :", os.path.basename(vid_path))
6015            predictions = load_pickle(pred_path)
6016            best_pred = predictions[list(predictions.keys())[0]].numpy() > pred_thresh
6017            behaviors = self.get_behavior_dictionary(episode_name)
6018            # Load video
6019            labels_pred = [behaviors[i] for i in range(len(behaviors))]
6020            labels_pred = np.roll(
6021                labels_pred, 1
6022            ).tolist() 
6023
6024            gt_data = None
6025            if ground_truth_file_paths is not None:
6026                gt_data = load_pickle(ground_truth_file_paths[k])
6027                labels_gt = gt_data[1]
6028                gt_data = binarize_data(gt_data, max_frame=best_pred.shape[1])
6029
6030            cap = cv2.VideoCapture(vid_path)
6031            cap.set(cv2.CAP_PROP_POS_FRAMES, start)
6032            end = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) if end < 0 else end
6033            fps = cap.get(cv2.CAP_PROP_FPS)
6034            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
6035            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
6036            fourcc = cv2.VideoWriter_fourcc(*"mp4v")
6037            out = cv2.VideoWriter(
6038                os.path.join(
6039                    os.path.dirname(vid_path),
6040                    os.path.splitext(os.path.basename(vid_path))[0] + "_annotated.mp4",
6041                ),
6042                fourcc,
6043                fps,
6044                # (width + int(width/4) , height),
6045                (600, 300),
6046            )
6047            count = 0
6048            bar = tqdm(total=end - start)
6049            while cap.isOpened():
6050                ret, frame = cap.read()
6051                if not ret:
6052                    break
6053
6054                side_panel = self._create_side_panel(
6055                    height,
6056                    width,
6057                    labels_pred,
6058                    best_pred[:, count],
6059                    labels_gt,
6060                    gt_data[:, count],
6061                )
6062                frame = np.concatenate((frame, side_panel), axis=1)
6063                frame = cv2.resize(frame, (0, 0), fx=0.25, fy=0.25)
6064                out.write(frame)
6065                count += 1
6066                bar.update(1)
6067
6068                if count > end:
6069                    break
6070
6071            cap.release()
6072            out.release()
6073            cv2.destroyAllWindows()

Create a video with the predictions overlaid on the video

def plot_predictions( self, episode_name: str, load_epoch: int = None, parameters_update: Dict = None, add_legend: bool = True, ground_truth: bool = True, colormap: str = 'dlc2action', hide_axes: bool = False, min_classes: int = 1, width: float = 10, whole_video: bool = False, transparent: bool = False, drop_classes: Set = None, search_classes: Set = None, num_plots: int = 1, remove_saved_features: bool = False, smooth_interval_prediction: int = 0, data_path: str = None, file_paths: Set = None, mode: str = 'val', font_size: float = None, window_size: int = 400) -> None:
6075    def plot_predictions(
6076        self,
6077        episode_name: str,
6078        load_epoch: int = None,
6079        parameters_update: Dict = None,
6080        add_legend: bool = True,
6081        ground_truth: bool = True,
6082        colormap: str = "dlc2action",
6083        hide_axes: bool = False,
6084        min_classes: int = 1,
6085        width: float = 10,
6086        whole_video: bool = False,
6087        transparent: bool = False,
6088        drop_classes: Set = None,
6089        search_classes: Set = None,
6090        num_plots: int = 1,
6091        remove_saved_features: bool = False,
6092        smooth_interval_prediction: int = 0,
6093        data_path: str = None,
6094        file_paths: Set = None,
6095        mode: str = "val",
6096        font_size: float = None,
6097        window_size: int = 400,
6098    ) -> None:
6099        """Visualize random predictions.
6100
6101        Parameters
6102        ----------
6103        episode_name : str
6104            the name of the episode to load
6105        load_epoch : int, optional
6106            the epoch to load (by default last)
6107        parameters_update : dict, optional
6108            parameter update dictionary
6109        add_legend : bool, default True
6110            if True, legend will be added to the plot
6111        ground_truth : bool, default True
6112            if True, ground truth will be added to the plot
6113        colormap : str, default 'Accent'
6114            the `matplotlib` colormap to use
6115        hide_axes : bool, default True
6116            if `True`, the axes will be hidden on the plot
6117        min_classes : int, default 1
6118            the minimum number of classes in a displayed interval
6119        width : float, default 10
6120            the width of the plot
6121        whole_video : bool, default False
6122            if `True`, whole videos are plotted instead of segments
6123        transparent : bool, default False
6124            if `True`, the background on the plot is transparent
6125        drop_classes : set, optional
6126            a set of class names to not be displayed
6127        search_classes : set, optional
6128            if given, only intervals where at least one of the classes is in ground truth will be shown
6129        num_plots : int, default 1
6130            the number of plots to make
6131        remove_saved_features : bool, default False
6132            if `True`, the dataset will be deleted after computation
6133        smooth_interval_prediction : int, default 0
6134            if >0, predictions shorter than this number of frames are removed (filled with prediction for the previous frame)
6135        data_path : str, optional
6136            the data path to run the prediction for
6137        file_paths : set, optional
6138            a set of string file paths (data with all prefixes + feature files, in any order) to run the prediction
6139            for
6140        mode : {'all', 'test', 'val', 'train'}
6141            the subset of the data to make the prediction for (forced to 'all' if data_path is not None)
6142
6143        """
6144        plot_path = os.path.join(self.project_path, "results", "plots")
6145        task, parameters, mode = self._make_task_prediction(
6146            "_",
6147            load_episode=episode_name,
6148            parameters_update=parameters_update,
6149            load_epoch=load_epoch,
6150            data_path=data_path,
6151            file_paths=file_paths,
6152            mode=mode,
6153        )
6154        os.makedirs(plot_path, exist_ok=True)
6155        task.visualize_results(
6156            save_path=os.path.join(plot_path, f"{episode_name}_prediction.svg"),
6157            add_legend=add_legend,
6158            ground_truth=ground_truth,
6159            colormap=colormap,
6160            hide_axes=hide_axes,
6161            min_classes=min_classes,
6162            whole_video=whole_video,
6163            transparent=transparent,
6164            dataset=mode,
6165            drop_classes=drop_classes,
6166            search_classes=search_classes,
6167            width=width,
6168            smooth_interval_prediction=smooth_interval_prediction,
6169            font_size=font_size,
6170            num_plots=num_plots,
6171            window_size=window_size,
6172        )
6173        if remove_saved_features:
6174            self._remove_stores(parameters)

Visualize random predictions.

Parameters

episode_name : str the name of the episode to load load_epoch : int, optional the epoch to load (by default last) parameters_update : dict, optional parameter update dictionary add_legend : bool, default True if True, legend will be added to the plot ground_truth : bool, default True if True, ground truth will be added to the plot colormap : str, default 'dlc2action' the matplotlib colormap to use hide_axes : bool, default False if True, the axes will be hidden on the plot min_classes : int, default 1 the minimum number of classes in a displayed interval width : float, default 10 the width of the plot whole_video : bool, default False if True, whole videos are plotted instead of segments transparent : bool, default False if True, the background on the plot is transparent drop_classes : set, optional a set of class names to not be displayed search_classes : set, optional if given, only intervals where at least one of the classes is in ground truth will be shown num_plots : int, default 1 the number of plots to make remove_saved_features : bool, default False if True, the dataset will be deleted after computation smooth_interval_prediction : int, default 0 if >0, predictions shorter than this number of frames are removed (filled with prediction for the previous frame) data_path : str, optional the data path to run the prediction for file_paths : set, optional a set of string file paths (data with all prefixes + feature files, in any order) to run the prediction for mode : {'all', 'test', 'val', 'train'} the subset of the data to make the prediction for (forced to 'all' if data_path is not None)

def create_video_from_labels( self, video_dir_path: str, mode='ground_truth', prediction_name: str = None, save_path: str = None):
6176    def create_video_from_labels(
6177        self,
6178        video_dir_path: str,
6179        mode="ground_truth",
6180        prediction_name: str = None,
6181        save_path: str = None,
6182    ):
6183        if save_path is None:
6184            save_path = os.path.join(
6185                self.project_path, "results", f"annotated_videos_from_{mode}"
6186            )
6187        os.makedirs(save_path, exist_ok=True)
6188
6189        params = self._read_parameters(catch_blanks=False)
6190
6191        if mode == "ground_truth":
6192            source_dir = self.annotation_path
6193            annotation_suffix = params["data"]["annotation_suffix"]
6194        elif mode == "prediction":
6195            assert (
6196                not prediction_name is None
6197            ), "Please provide a prediction name with mode 'prediction'"
6198            source_dir = os.path.join(
6199                self.project_path, "results", "predictions", prediction_name
6200            )
6201            annotation_suffix = f"_{prediction_name}_prediction.pickle"
6202
6203        video_annotation_pairs = [
6204            (
6205                os.path.join(video_dir_path, f),
6206                os.path.join(
6207                    source_dir, f.replace(f.split(".")[-1], annotation_suffix)
6208                ),
6209            )
6210            for f in os.listdir(video_dir_path)
6211            if os.path.exists(
6212                os.path.join(source_dir, f.replace(f.split(".")[-1], annotation_suffix))
6213            )
6214        ]
6215
6216        for video_file, annotation_file in tqdm(video_annotation_pairs):
6217            if not os.path.exists(video_file):
6218                print(f"Video file {video_file} does not exist, skipping.")
6219                continue
6220            if not os.path.exists(annotation_file):
6221                print(f"Annotation file {annotation_file} does not exist, skipping.")
6222                continue
6223
6224            if annotation_file.endswith(".pickle"):
6225                annotations = load_pickle(annotation_file)
6226            elif annotation_file.endswith(".csv"):
6227                annotations = pd.read_csv(annotation_file)
6228
6229            if mode == "ground_truth":
6230                behaviors = annotations[1]
6231                annot_data = annotations[3]
6232            elif mode == "predictions":
6233                behaviors = list(annotations["classes"].values())
6234                annot_data = [
6235                    annotations[key]
6236                    for key in annotations.keys()
6237                    if key not in ["classes", "min_frame", "max_frame"]
6238                ]
6239                if params["general"]["exclusive"]:
6240                    annot_data = [np.argmax(annot, axis=1) for annot in annot_data]
6241                    seqs = [
6242                        [
6243                            self._bin_array_to_sequences(annot, target_value=k)
6244                            for k in range(len(behaviors))
6245                        ]
6246                        for annot in annot_data
6247                    ]
6248                else:
6249                    annot_data = [np.where(annot > 0.5)[0] for annot in annot_data]
6250                    seqs = [
6251                        self._bin_array_to_sequences(annot, target_value=1)
6252                        for annot in annot_data
6253                    ]
6254                annotations = ["", "", seqs]
6255
6256            for individual in annotations[3]:
6257                for behavior in annotations[3][individual]:
6258                    intervals = annotations[3][individual][behavior]
6259                    self._extract_videos(
6260                        video_file,
6261                        intervals,
6262                        behavior,
6263                        individual,
6264                        save_path,
6265                        resolution=(640, 480),
6266                        fps=30,
6267                    )
def create_metadata_backup(self) -> None:
6328    def create_metadata_backup(self) -> None:
6329        """Create a copy of the meta files."""
6330        meta_copy_path = os.path.join(self.project_path, "meta", "backup")
6331        meta_path = os.path.join(self.project_path, "meta")
6332        if os.path.exists(meta_copy_path):
6333            shutil.rmtree(meta_copy_path)
6334        os.mkdir(meta_copy_path)
6335        for file in os.listdir(meta_path):
6336            if file == "backup":
6337                continue
6338            if os.path.isdir(os.path.join(meta_path, file)):
6339                continue
6340            shutil.copy(
6341                os.path.join(meta_path, file), os.path.join(meta_copy_path, file)
6342            )

Create a copy of the meta files.

def load_metadata_backup(self) -> None:
6344    def load_metadata_backup(self) -> None:
6345        """Load from previously created meta data backup (in case of corruption)."""
6346        meta_copy_path = os.path.join(self.project_path, "meta", "backup")
6347        meta_path = os.path.join(self.project_path, "meta")
6348        for file in os.listdir(meta_copy_path):
6349            shutil.copy(
6350                os.path.join(meta_copy_path, file), os.path.join(meta_path, file)
6351            )

Load from previously created meta data backup (in case of corruption).

def get_behavior_dictionary(self, episode_name: str) -> Dict:
6353    def get_behavior_dictionary(self, episode_name: str) -> Dict:
6354        """Get the behavior dictionary for an episode.
6355
6356        Parameters
6357        ----------
6358        episode_name : str
6359            the name of the episode
6360
6361        Returns
6362        -------
6363        behaviors_dictionary : dict
6364            a dictionary where keys are label indices and values are label names
6365
6366        """
6367        return self._episode(episode_name).get_behaviors_dict()

Get the behavior dictionary for an episode.

Parameters

episode_name : str the name of the episode

Returns

behaviors_dictionary : dict a dictionary where keys are label indices and values are label names

def import_episodes( self, episodes_directory: str, name_map: Dict = None, repeat_policy: str = 'error') -> None:
6369    def import_episodes(
6370        self,
6371        episodes_directory: str,
6372        name_map: Dict = None,
6373        repeat_policy: str = "error",
6374    ) -> None:
6375        """Import episodes exported with `Project.export_episodes`.
6376
6377        Parameters
6378        ----------
6379        episodes_directory : str
6380            the path to the exported episodes directory
6381        name_map : dict, optional
6382            a name change dictionary for the episodes: keys are old names, values are new names
6383        repeat_policy : {'error', 'skip', 'force'}, default 'error'
6384            the policy for repeated episode names: 'error' raises an error, 'skip' skips duplicates,
6385            'force' overwrites existing episodes
6386
6387        """
6388        if name_map is None:
6389            name_map = {}
6390        episodes = pd.read_pickle(os.path.join(episodes_directory, "episodes.pickle"))
6391        to_remove = []
6392        import_string = "Imported episodes: "
6393        for episode_name in episodes.index:
6394            if episode_name in name_map:
6395                import_string += f"{episode_name} "
6396                episode_name = name_map[episode_name]
6397                import_string += f"({episode_name}), "
6398            else:
6399                import_string += f"{episode_name}, "
6400            try:
6401                self._check_episode_validity(episode_name, allow_doublecolon=True)
6402            except ValueError as e:
6403                if str(e).endswith("is already taken!"):
6404                    if repeat_policy == "skip":
6405                        to_remove.append(episode_name)
6406                    elif repeat_policy == "force":
6407                        self.remove_episode(episode_name)
6408                    elif repeat_policy == "error":
6409                        raise ValueError(
6410                            f"The {episode_name} episode name is already taken; please use the name_map parameter to rename it"
6411                        )
6412                    else:
6413                        raise ValueError(
6414                            f"The {repeat_policy} repeat policy is not recognized; please choose from ['skip', 'force' and 'error']"
6415                        )
6416        episodes = episodes.drop(index=to_remove)
6417        self._episodes().update(
6418            episodes,
6419            name_map=name_map,
6420            force=(repeat_policy == "force"),
6421            data_path=self.data_path,
6422            annotation_path=self.annotation_path,
6423        )
6424        for episode_name in episodes.index:
6425            if episode_name in name_map:
6426                new_episode_name = name_map[episode_name]
6427            else:
6428                new_episode_name = episode_name
6429            model_dir = os.path.join(
6430                self.project_path, "results", "model", new_episode_name
6431            )
6432            old_model_dir = os.path.join(episodes_directory, "model", episode_name)
6433            if os.path.exists(model_dir):
6434                shutil.rmtree(model_dir)
6435            os.mkdir(model_dir)
6436            for file in os.listdir(old_model_dir):
6437                shutil.copyfile(
6438                    os.path.join(old_model_dir, file), os.path.join(model_dir, file)
6439                )
6440            log_file = os.path.join(
6441                self.project_path, "results", "logs", f"{new_episode_name}.txt"
6442            )
6443            old_log_file = os.path.join(
6444                episodes_directory, "logs", f"{episode_name}.txt"
6445            )
6446            shutil.copyfile(old_log_file, log_file)
6447        print(import_string)
6448        print("\n")

Import episodes exported with Project.export_episodes.

Parameters

episodes_directory : str the path to the exported episodes directory name_map : dict, optional a name change dictionary for the episodes: keys are old names, values are new names repeat_policy : {'error', 'skip', 'force'}, default 'error' the policy for repeated episode names: 'error' raises an error, 'skip' skips duplicates, 'force' overwrites existing episodes

def export_episodes( self, episode_names: List, output_directory: str, name: str = None) -> None:
6450    def export_episodes(
6451        self, episode_names: List, output_directory: str, name: str = None
6452    ) -> None:
6453        """Save selected episodes as a file that can be imported into another project with `Project.import_episodes`.
6454
6455        Parameters
6456        ----------
6457        episode_names : list
6458            a list of string episode names
6459        output_directory : str
6460            the path to the directory where the episodes will be saved
6461        name : str, optional
6462            the name of the episodes directory (by default `exported_episodes`)
6463
6464        """
6465        if name is None:
6466            name = "exported_episodes"
6467        if os.path.exists(
6468            os.path.join(output_directory, name + ".zip")
6469        ) or os.path.exists(os.path.join(output_directory, name)):
6470            i = 1
6471            while os.path.exists(
6472                os.path.join(output_directory, name + f"_{i}.zip")
6473            ) or os.path.exists(os.path.join(output_directory, name + f"_{i}")):
6474                i += 1
6475            name = name + f"_{i}"
6476        dest_dir = os.path.join(output_directory, name)
6477        os.mkdir(dest_dir)
6478        os.mkdir(os.path.join(dest_dir, "model"))
6479        os.mkdir(os.path.join(dest_dir, "logs"))
6480        runs = []
6481        for episode in episode_names:
6482            runs += self._episodes().get_runs(episode)
6483        for run in runs:
6484            shutil.copytree(
6485                os.path.join(self.project_path, "results", "model", run),
6486                os.path.join(dest_dir, "model", run),
6487            )
6488            shutil.copyfile(
6489                os.path.join(self.project_path, "results", "logs", f"{run}.txt"),
6490                os.path.join(dest_dir, "logs", f"{run}.txt"),
6491            )
6492        data = self._episodes().get_subset(runs)
6493        data.to_pickle(os.path.join(dest_dir, "episodes.pickle"))

Save selected episodes as a file that can be imported into another project with Project.import_episodes.

Parameters

episode_names : list a list of string episode names output_directory : str the path to the directory where the episodes will be saved name : str, optional the name of the episodes directory (by default exported_episodes)

def get_results_table( self, episode_names: List, metrics: List = None, mode: str = 'mean', print_results: bool = True, classes: List = None):
6495    def get_results_table(
6496        self,
6497        episode_names: List,
6498        metrics: List = None,
6499        mode: str = "mean",  # Choose between ["mean", "statistics", "detail"]
6500        print_results: bool = True,
6501        classes: List = None,
6502    ):
6503        """Generate a `pandas` dataframe with a summary of episode results.
6504
6505        Parameters
6506        ----------
6507        episode_names : list
6508            a list of names of episodes to include
6509        metrics : list, optional
6510            a list of metric names to include
6511        mode : bool, optional
6512            the mode of the results table, choose between ["mean", "statistics", "detail"], by default "mean"
6513        print_results : bool, optional
6514            if True, the results will be printed to the console, by default True
6515        classes : list, optional
6516            a list of names of classes to include (by default all are included)
6517
6518        Returns
6519        -------
6520        results : pd.DataFrame
6521            a table with the results
6522
6523        """
6524        run_names = []
6525        for episode in episode_names:
6526            run_names += self._episodes().get_runs(episode)
6527        episodes = self.list_episodes(run_names, print_results=False)
6528        metric_columns = [x for x in episodes.columns if x[0] == "results"]
6529        results_df = pd.DataFrame()
6530        if metrics is not None:
6531            metric_columns = [
6532                x for x in metric_columns if x[1].split("_")[0] in metrics
6533            ]
6534        for episode in episode_names:
6535            results = []
6536            metric_set = set()
6537            for run in self._episodes().get_runs(episode):
6538                beh_dict = self.get_behavior_dictionary(run)
6539                res_dict = defaultdict(lambda: {})
6540                for column in metric_columns:
6541                    if np.isnan(episodes.loc[run, column]):
6542                        continue
6543                    split = column[1].split("_")
6544                    if split[-1].isnumeric():
6545                        beh_ind = int(split[-1])
6546                        metric_name = "_".join(split[:-1])
6547                        beh = beh_dict[beh_ind]
6548                    else:
6549                        beh = "average"
6550                        metric_name = column[1]
6551                    res_dict[beh][metric_name] = episodes.loc[run, column]
6552                    metric_set.add(metric_name)
6553                if "average" not in res_dict:
6554                    res_dict["average"] = {}
6555                for metric in metric_set:
6556                    if metric not in res_dict["average"]:
6557                        arr = [
6558                            res_dict[beh][metric]
6559                            for beh in res_dict
6560                            if metric in res_dict[beh]
6561                        ]
6562                        res_dict["average"][metric] = np.mean(arr)
6563                results.append(res_dict)
6564            episode_results = {}
6565            for metric in metric_set:
6566                for beh in results[0].keys():
6567                    if classes is not None and beh not in classes:
6568                        continue
6569                    arr = []
6570                    for res_dict in results:
6571                        if metric in res_dict[beh]:
6572                            arr.append(res_dict[beh][metric])
6573                    if len(arr) > 0:
6574                        if mode == "statistics":
6575                            episode_results[(beh, f"{episode} {metric} mean")] = (
6576                                np.mean(arr)
6577                            )
6578                            episode_results[(beh, f"{episode} {metric} std")] = np.std(
6579                                arr
6580                            )
6581                        elif mode == "mean":
6582                            episode_results[(beh, f"{episode} {metric}")] = np.mean(arr)
6583                        elif mode == "detail":
6584                            for i, val in enumerate(arr):
6585                                episode_results[(beh, f"{episode}::{i} {metric}")] = val
6586            for key, value in episode_results.items():
6587                results_df.loc[key[0], key[1]] = value
6588        if print_results:
6589            print(f"RESULTS:")
6590            print(results_df)
6591            print("\n")
6592        return results_df

Generate a pandas dataframe with a summary of episode results.

Parameters

episode_names : list a list of names of episodes to include metrics : list, optional a list of metric names to include mode : str, optional the mode of the results table, choose between ["mean", "statistics", "detail"], by default "mean" print_results : bool, optional if True, the results will be printed to the console, by default True classes : list, optional a list of names of classes to include (by default all are included)

Returns

results : pd.DataFrame a table with the results

def episode_exists(self, episode_name: str) -> bool:
6594    def episode_exists(self, episode_name: str) -> bool:
6595        """Check if an episode already exists.
6596
6597        Parameters
6598        ----------
6599        episode_name : str
6600            the episode name
6601
6602        Returns
6603        -------
6604        exists : bool
6605            `True` if the episode exists
6606
6607        """
6608        return self._episodes().check_name_validity(episode_name)

Check if an episode already exists.

Parameters

episode_name : str the episode name

Returns

exists : bool True if the episode exists

def search_exists(self, search_name: str) -> bool:
6610    def search_exists(self, search_name: str) -> bool:
6611        """Check if a search already exists.
6612
6613        Parameters
6614        ----------
6615        search_name : str
6616            the search name
6617
6618        Returns
6619        -------
6620        exists : bool
6621            `True` if the search exists
6622
6623        """
6624        return self._searches().check_name_validity(search_name)

Check if a search already exists.

Parameters

search_name : str the search name

Returns

exists : bool True if the search exists

def prediction_exists(self, prediction_name: str) -> bool:
6626    def prediction_exists(self, prediction_name: str) -> bool:
6627        """Check if a prediction already exists.
6628
6629        Parameters
6630        ----------
6631        prediction_name : str
6632            the prediction name
6633
6634        Returns
6635        -------
6636        exists : bool
6637            `True` if the prediction exists
6638
6639        """
6640        return self._predictions().check_name_validity(prediction_name)

Check if a prediction already exists.

Parameters

prediction_name : str the prediction name

Returns

exists : bool True if the prediction exists

@staticmethod
def project_name_available(projects_path: str, project_name: str):
6642    @staticmethod
6643    def project_name_available(projects_path: str, project_name: str):
6644        """Check if a project name is available.
6645
6646        Parameters
6647        ----------
6648        projects_path : str
6649            the path to the projects directory
6650        project_name : str
6651            the name of the project to check
6652
6653        Returns
6654        -------
6655        available : bool
6656            `True` if the project name is available
6657
6658        """
6659        if projects_path is None:
6660            projects_path = os.path.join(str(Path.home()), "DLC2Action")
6661        return not os.path.exists(os.path.join(projects_path, project_name))

Check if a project name is available.

Parameters

projects_path : str
    the path to the projects directory (if None, the "DLC2Action" folder in the user's home directory is used)
project_name : str
    the name of the project to check

Returns

available : bool True if the project name is available

def rename_episode(self, episode_name: str, new_episode_name: str):
6676    def rename_episode(self, episode_name: str, new_episode_name: str):
6677        """Rename an episode.
6678
6679        Parameters
6680        ----------
6681        episode_name : str
6682            the current episode name
6683        new_episode_name : str
6684            the new episode name
6685
6686        """
6687        shutil.move(
6688            os.path.join(self.project_path, "results", "model", episode_name),
6689            os.path.join(self.project_path, "results", "model", new_episode_name),
6690        )
6691        shutil.move(
6692            os.path.join(self.project_path, "results", "logs", f"{episode_name}.txt"),
6693            os.path.join(
6694                self.project_path, "results", "logs", f"{new_episode_name}.txt"
6695            ),
6696        )
6697        self._episodes().rename_episode(episode_name, new_episode_name)

Rename an episode.

Parameters

episode_name : str
    the current episode name
new_episode_name : str
    the new episode name